diff --git a/sim/algo/tdmpc/src/algorithm/__init__.py b/sim/algo/tdmpc/src/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sim/algo/tdmpc/src/algorithm/helper.py b/sim/algo/tdmpc/src/algorithm/helper.py new file mode 100644 index 00000000..28d8093f --- /dev/null +++ b/sim/algo/tdmpc/src/algorithm/helper.py @@ -0,0 +1,764 @@ +import glob +import os +import pickle +import re +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from numpy.linalg import norm +from scipy.spatial import distance +from torch import distributions as pyd +from torch.distributions.utils import _standard_normal + +__REDUCE__ = lambda b: "mean" if b else "none" + + +def l1(pred, target, reduce=False): + """Computes the L1-loss between predictions and targets.""" + return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def mse(pred, target, reduce=False): + """Computes the MSE loss between predictions and targets.""" + return F.mse_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def bce(pred, target, logits=True, reduce=False): + """Computes the BCE loss between predictions and targets.""" + if logits: + return F.binary_cross_entropy_with_logits(pred, target, reduction=__REDUCE__(reduce)) + return F.binary_cross_entropy(pred, target, reduction=__REDUCE__(reduce)) + + +def l1_quantile(y_true, y_pred, quantile=0.3): + """ + Compute the quantile loss. + + Args: + y_true (torch.Tensor): Ground truth values + y_pred (torch.Tensor): Predicted values + quantile (float): Quantile to compute, must be between 0 and 1 + + Returns: + torch.Tensor: Quantile loss + """ + errors = y_true - y_pred + loss = torch.max((quantile - 1) * errors, quantile * errors) + return torch.mean(loss) + + +def threshold_l2_expectile(diff, threshold=1e-2, expectile=0.99, reduce=False): + weight = torch.where(torch.abs(diff) > threshold, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def l2_expectile(diff, expectile=0.7, reduce=False): + weight = torch.where(diff > 0, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def mse_expectile(pred, target, expectile=0.7, reduce=False): + diff = pred - target + weight = torch.where(diff > 0, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def _get_out_shape(in_shape, layers): + """Utility function. 
Returns the output shape of a network for a given input shape.""" + x = torch.randn(*in_shape).unsqueeze(0) + return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape + + +def gaussian_logprob(eps, log_std): + """Compute Gaussian log probability.""" + residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True) + return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1) + + +def squash(mu, pi, log_pi): + """Apply squashing function.""" + mu = torch.tanh(mu) + pi = torch.tanh(pi) + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def orthogonal_init(m): + """Orthogonal layer initialization.""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def ema(m, m_target, tau): + """Update slow-moving average of online network (target network) at rate tau.""" + with torch.no_grad(): + for p, p_target in zip(m.parameters(), m_target.parameters()): + p_target.data.lerp_(p.data, tau) + + +def set_requires_grad(net, value): + """Enable/disable gradients for a given (sub)network.""" + for param in net.parameters(): + param.requires_grad_(value) + + +def linear_schedule(schdl, step): + """ + Outputs values following a linear decay schedule. + Adapted from https://github.com/facebookresearch/drqv2 + """ + try: + return float(schdl) + except ValueError: + match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) + if match: + init, final, duration = [float(g) for g in match.groups()] + mix = np.clip(step / duration, 0.0, 1.0) + return (1.0 - mix) * init + mix * final + raise NotImplementedError(schdl) + + +class TruncatedNormal(pyd.Normal): + """Utility class implementing the truncated normal distribution.""" + + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): + super().__init__(loc, scale, validate_args=False) + self.low = low + self.high = high + self.eps = eps + + def _clamp(self, x): + clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) + x = x - x.detach() + clamped_x.detach() + return x + + def sample(self, clip=None, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + eps *= self.scale + if clip is not None: + eps = torch.clamp(eps, -clip, clip) + x = self.loc + eps + return self._clamp(x) + + +class NormalizeImg(nn.Module): + """Normalizes pixel observations to [0,1) range.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.div(255.0) + + +class Flatten(nn.Module): + """Flattens its input to a (batched) vector.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +def enc(cfg): + """Returns a TOLD encoder.""" + pixels_enc_layers, state_enc_layers = None, None + if cfg.modality in {"pixels", "all"}: + C = int(3 * cfg.frame_stack) + pixels_enc_layers = [ + NormalizeImg(), + nn.Conv2d(C, cfg.num_channels, 7, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + ] + out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers) + pixels_enc_layers.extend( + [ 
+ Flatten(), + nn.Linear(np.prod(out_shape), cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Sigmoid(), + ] + ) + if cfg.modality == "pixels": + return ConvExt(nn.Sequential(*pixels_enc_layers)) + if cfg.modality in {"state", "all"}: + state_dim = cfg.obs_shape[0] if cfg.modality == "state" else cfg.obs_shape["state"][0] + state_enc_layers = [ + nn.Linear(state_dim, cfg.enc_dim), + nn.LayerNorm(cfg.enc_dim), + nn.ELU(), + nn.Linear(cfg.enc_dim, cfg.enc_dim), + nn.LayerNorm(cfg.enc_dim), + nn.ELU(), + nn.Linear(cfg.enc_dim, cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Tanh(), + ] + if cfg.modality == "state": + return nn.Sequential(*state_enc_layers) + else: + raise NotImplementedError + + encoders = {} + for k in cfg.obs_shape: + if k == "state": + encoders[k] = nn.Sequential(*state_enc_layers) + elif k.endswith("rgb"): + encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers)) + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(encoders)) + + +def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns an MLP.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] + layers += [nn.Linear(mlp_dim[-1], out_dim)] + return nn.Sequential(*layers) + + +def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): + """Returns a dynamics network.""" + return nn.Sequential( + mlp(in_dim, mlp_dim, out_dim, act_fn), + nn.LayerNorm(out_dim), + act_fn, + ) + + +def q(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns a Q-function that uses Layer Normalization.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] + layers += [nn.Linear(mlp_dim[-1], 1)] + return nn.Sequential(*layers) + + +def v(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns a Q-function that uses Layer Normalization.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] + layers += [[nn.Linear(mlp_dim[-1], 1)]] + return nn.Sequential(*layers) + + +def aug(cfg): + if cfg.modality == "state": + return nn.Identity() + else: + augs = {} + for k in cfg.obs_shape: + if k == "state": + augs[k] = nn.Identity() + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(augs)) + + +class ConvExt(nn.Module): + def __init__(self, conv): + super().__init__() + self.conv = conv + + def forward(self, x): + if x.ndim > 4: + batch_shape = x.shape[:-3] + out = self.conv(x.view(-1, *x.shape[-3:])) + out = out.view(*batch_shape, *out.shape[1:]) + else: + out = self.conv(x) + return out + + +class Multiplexer(nn.Module): + + def __init__(self, choices): + super().__init__() + self.choices = choices + + def forward(self, x, key=None): + if isinstance(x, dict): + if key is not None: + return 
self.choices[key](x) + return {k: self.choices[k](_x) for k, _x in x.items()} + return self.choices(x) + + +class Episode(object): + """Storage object for a single episode.""" + + def __init__(self, cfg, init_obs): + self.cfg = cfg + self.device = torch.device(cfg.buffer_device) + self.capacity = int(cfg.max_episode_length // cfg.action_repeat) + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + self.next_obses = torch.empty( + (cfg.num_envs, self.capacity, *init_obs.shape[1:]), + dtype=dtype, + device=self.device, + ) + self.obses = torch.empty( + (cfg.num_envs, self.capacity, *init_obs.shape[1:]), + dtype=dtype, + device=self.device, + ) + self.obses[:, 0] = init_obs.clone().to(self.device, dtype=dtype) + elif cfg.modality == "all": + self.obses = {} + for k, v in init_obs.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.next_obses[k] = torch.empty( + (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device + ) + self.obses[k] = torch.empty( + (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device + ) + self.obses[k][:, 0] = v.clone().to(self.device, dtype=dtype) + else: + raise ValueError + self.actions = torch.empty( + (cfg.num_envs, self.capacity, cfg.action_dim), + dtype=torch.float32, + device=self.device, + ) + self.rewards = torch.empty( + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.float32, + device=self.device, + ) + self.dones = torch.empty( + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.bool, + device=self.device, + ) + self.successes = torch.empty( + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.bool, + device=self.device, + ) + self.masks = torch.zeros( + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.float32, + device=self.device, + ) + self.cumulative_reward = torch.tensor([0.0] * cfg.num_envs) + self.done = torch.tensor([False] * cfg.num_envs) + self.success = torch.tensor([False] * cfg.num_envs) + self._idx = 0 + + def __len__(self): + return self._idx + + @property + def episode_length(self): + num_dones = self.dones[:, : self._idx].sum().item() + if num_dones > 0: + return float(self._idx) * self.cfg.num_envs / num_dones + return float(self._idx) + + @classmethod + def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): + """Constructs an episode from a trajectory.""" + + if cfg.modality in {"pixels", "state"}: + episode = cls(cfg, obses[0]) + episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device) + elif cfg.modality == "all": + episode = cls(cfg, {k: v[0] for k, v in obses.items()}) + for k, v in obses.items(): + episode.obses[k][1:] = torch.tensor(obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device) + else: + raise NotImplementedError + episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device) + episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device) + episode.dones = ( + torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) + if dones is not None + else torch.zeros_like(episode.dones) + ) + episode.masks = ( + torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device) + if masks is not None + else torch.ones_like(episode.masks) + ) + episode.cumulative_reward = torch.sum(episode.rewards) + episode.done = True + episode._idx = cfg.episode_length + return episode + + @property + def first(self): + return 
len(self) == 0 + + @property + def full(self): + return len(self) == self.capacity + + @property + def buffer_capacity(self): + return self.capacity + + def __add__(self, transition): + self.add(*transition) + return self + + def add(self, obs, action, reward, done, timeouts, success=False): + if isinstance(obs, dict): + for k, v in obs.items(): + if self._idx == self.capacity - 1: + self.next_obses[k][:, self._idx] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype) + elif self._idx < self.capacity - 1: + self.obses[k][:, self._idx + 1] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype) + self.next_obses[k][:, self._idx] = self.obses[k][:, self._idx + 1].clone() + else: + if self._idx == self.capacity - 1: + self.next_obses[:, self._idx] = obs.clone().to(self.obses.device, dtype=self.obses.dtype) + elif self._idx < self.capacity - 1: + self.obses[:, self._idx + 1] = obs.clone().to(self.obses.device, dtype=self.obses.dtype) + self.next_obses[:, self._idx] = self.obses[:, self._idx + 1].clone() + self.actions[:, self._idx] = action.detach().cpu() + self.rewards[:, self._idx] = reward.detach().cpu() + self.dones[:, self._idx] = done.detach().cpu() + self.masks[:, self._idx] = 1.0 - timeouts.detach().cpu().float() # TODO + self.cumulative_reward += reward.detach().cpu() + self.done = done.detach().cpu() + self.success = torch.logical_or(self.success, success.detach().cpu()) + self.successes[:, self._idx] = torch.tensor(self.success).to(self.device) + self._idx += 1 + + +class ReplayBuffer: + """ + Storage and sampling functionality for training TD-MPC / TOLD. + The replay buffer is stored in GPU memory when training from state. + Uses prioritized experience replay by default.""" + + def __init__(self, cfg, dataset=None): + self.cfg = cfg + self.buffer_device = torch.device(cfg.buffer_device) + self.device = torch.device(cfg.device) + self.batch_size = self.cfg.batch_size + + print("Replay buffer device: ", self.buffer_device) + print("Replay buffer sample device: ", self.device) + + if dataset is not None: + self.capacity = max(dataset["rewards"].shape[0], cfg.max_offline_buffer_size) + print("Offline dataset size: ", dataset["rewards"].shape[0]) + else: + self.capacity = max(cfg.train_steps, cfg.max_buffer_size) + + print("Maximum capacity of the buffer is: ", self.capacity) + + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + # Note self.obs_shape always has single frame, which is different from cfg.obs_shape + self.obs_shape = cfg.obs_shape if cfg.modality == "state" else (3, *cfg.obs_shape[-2:]) + self._obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device) + self._next_obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device) + elif cfg.modality == "all": + self.obs_shape = {} + self._obs, self._next_obs = {}, {} + for k, v in cfg.obs_shape.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.obs_shape[k] = v if k == "state" else (3, *v[-2:]) + self._obs[k] = torch.empty( + (self.capacity, *self.obs_shape[k]), + dtype=dtype, + device=self.buffer_device, + ) + self._next_obs[k] = self._obs[k].clone() + else: + raise ValueError + + self._action = torch.empty( + (self.capacity, cfg.action_dim), + dtype=torch.float32, + device=self.buffer_device, + ) + self._reward = torch.empty((self.capacity,), dtype=torch.float32, device=self.buffer_device) + self._mask = torch.empty((self.capacity,), 
dtype=torch.float32, device=self.buffer_device) + self._done = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device) + self._success = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device) + self._priorities = torch.ones((self.capacity,), dtype=torch.float32, device=self.buffer_device) + self.ep_len = int(self.cfg.max_episode_length // self.cfg.action_repeat) + self._eps = 1e-6 + self._full = False + self.idx = 0 + self._sampling_idx = 0 + if dataset is not None: + self.init_from_offline_dataset(dataset) + + self._aug = aug(cfg) + + def init_from_offline_dataset(self, dataset): + assert self.idx == 0 and not self._full + n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent) + + def copy_data(dst, src, n): + assert isinstance(dst, dict) == isinstance(src, dict) + if isinstance(dst, dict): + for k in dst: + copy_data(dst[k], src[k], n) + else: + dst[:n] = torch.from_numpy(src[:n]) + + copy_data(self._obs, dataset["observations"], n_transitions) + copy_data(self._next_obs, dataset["next_observations"], n_transitions) + copy_data(self._action, dataset["actions"], n_transitions) + if self.cfg.task.startswith("xarm"): + # success = self._calc_sparse_success(dataset['success']) + # copy_data(self._reward, success.astype(np.float32), n_transitions) + if self.cfg.sparse_reward: + copy_data(self._reward, dataset["success"].astype(np.float32), n_transitions) + copy_data(self._success, dataset["success"], n_transitions) + else: + copy_data(self._reward, dataset["rewards"], n_transitions) + copy_data(self._mask, dataset["masks"], n_transitions) + copy_data(self._done, dataset["dones"], n_transitions) + self.idx = (self.idx + n_transitions) % self.capacity + self._full = n_transitions >= self.capacity + mask_idxs = np.array([n_transitions - i for i in range(1, self.cfg.horizon)]) + # _, episode_ends, _ = get_trajectory_boundaries_and_returns(dataset) + # mask_idxs = np.array([np.array(episode_ends) - i for i in range(1, self.cfg.horizon)]).T.flatten() + mask_idxs = np.clip(mask_idxs, 0, n_transitions - 1) + self._priorities[mask_idxs] = 0 + + def __add__(self, episode: Episode): + self.add(episode) + return self + + def add(self, episode: Episode): + idxs = torch.arange(self.idx, self.idx + self.cfg.num_envs * self.ep_len) % self.capacity + self._sampling_idx = (self.idx + self.cfg.num_envs * self.ep_len) % self.capacity + mask_copy = episode.masks.clone() + mask_copy[:, episode._idx - self.cfg.horizon :] = 0.0 + if self.cfg.modality in {"pixels", "state"}: + self._obs[idxs] = ( + episode.obses.flatten(0, 1) if self.cfg.modality == "state" else episode.obses[:, -3:].flatten(0, 1) + ) + self._next_obs[idxs] = ( + episode.next_obses.flatten(0, 1) + if self.cfg.modality == "state" + else episode.next_obses[:, -3:].flatten(0, 1) + ) + elif self.cfg.modality == "all": + for k, v in episode.obses.items(): + assert k in {"rgb", "state"} + assert k in self._obs + assert k in self._next_obs + if k == "rgb": + self._obs[k][idxs] = episode.obses[k][:-1, -3:].flatten(0, 1) + self._next_obs[k][idxs] = episode.obses[k][1:, -3:].flatten(0, 1) + else: + self._obs[k][idxs] = episode.obses[k][:-1].flatten(0, 1) + self._next_obs[k][idxs] = episode.obses[k][1:].flatten(0, 1) + self._action[idxs] = episode.actions.flatten(0, 1) + self._reward[idxs] = episode.rewards.flatten(0, 1) + self._mask[idxs] = mask_copy.flatten(0, 1) + self._done[idxs] = episode.dones.flatten(0, 1) + self._success[idxs] = episode.successes.flatten(0, 1) + if self._full: + max_priority = 
self._priorities.max().to(self.device).item()
+        else:
+            max_priority = 1.0 if self.idx == 0 else self._priorities[: self.idx].max().to(self.device).item()
+        mask = torch.arange(self.ep_len) > self.ep_len - self.cfg.horizon
+        mask = torch.cat([mask] * self.cfg.num_envs)
+        new_priorities = torch.full((self.ep_len * self.cfg.num_envs,), max_priority, device=self.buffer_device)
+        new_priorities[mask] = 0
+        new_priorities = new_priorities * self._mask[idxs]
+        self._priorities[idxs] = new_priorities
+        self._full = self._full or (self.idx + self.ep_len * self.cfg.num_envs > self.capacity)
+        if episode.full:
+            self.idx = (self.idx + self.ep_len * self.cfg.num_envs) % self.capacity
+
+    def _set_bs(self, bs):
+        self.batch_size = bs
+
+    def update_priorities(self, idxs, priorities):
+        self._priorities[idxs.to(self.buffer_device)] = priorities.squeeze(1).to(self.buffer_device) + self._eps
+
+    def _get_obs(self, arr, idxs, bs=None, frame_stack=None):
+        if isinstance(arr, dict):
+            return {k: self._get_obs(v, idxs, bs=bs, frame_stack=frame_stack) for k, v in arr.items()}
+        if arr.ndim <= 2:  # if self.cfg.modality == 'state':
+            return arr[idxs].cuda(self.device)
+        obs = torch.empty(
+            (
+                self.cfg.batch_size if bs is None else bs,
+                3 * self.cfg.frame_stack if frame_stack is None else 3 * frame_stack,
+                *arr.shape[-2:],
+            ),
+            dtype=arr.dtype,
+            device=torch.device(self.device),
+        )
+        obs[:, -3:] = arr[idxs].cuda(self.device)
+        _idxs = idxs.clone()
+        mask = torch.ones_like(_idxs, dtype=torch.bool)
+        for i in range(1, self.cfg.frame_stack if frame_stack is None else frame_stack):
+            mask[_idxs % self.cfg.episode_length == 0] = False
+            _idxs[mask] -= 1
+            obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda(self.device)
+        return obs.float()
+
+    def save(self, buffer_fp):
+        data = {
+            "obs": self._obs.cpu(),
+            "next_obs": self._next_obs.cpu(),
+            "action": self._action.cpu(),
+            "reward": self._reward.cpu(),
+            "mask": self._mask.cpu(),
+            "done": self._done.cpu(),
+            "priorities": self._priorities.cpu(),
+        }
+        torch.save(data, buffer_fp)
+
+    def load(self, buffer_fp):
+        data = torch.load(buffer_fp)
+        n_transitions = data["obs"].shape[0]
+
+        if n_transitions >= self.capacity:
+            # Restore observations from the loaded data (not from the freshly allocated buffer).
+            self._obs = data["obs"][-self.capacity :].to(self.buffer_device)
+            self._next_obs = data["next_obs"][-self.capacity :].to(self.buffer_device)
+            self._action = data["action"][-self.capacity :].to(self.buffer_device)
+            self._reward = data["reward"][-self.capacity :].to(self.buffer_device)
+            self._mask = data["mask"][-self.capacity :].to(self.buffer_device)
+            self._done = data["done"][-self.capacity :].to(self.buffer_device)
+            self._priorities = data["priorities"][-self.capacity :].to(self.buffer_device)
+            self.idx = 0
+            self._full = True
+        else:
+            self._obs[:n_transitions] = data["obs"].to(self.buffer_device)
+            self._next_obs[:n_transitions] = data["next_obs"].to(self.buffer_device)
+            self._action[:n_transitions] = data["action"].to(self.buffer_device)
+            self._reward[:n_transitions] = data["reward"].to(self.buffer_device)
+            self._mask[:n_transitions] = data["mask"].to(self.buffer_device)
+            self._done[:n_transitions] = data["done"].to(self.buffer_device)
+            self._priorities[:n_transitions] = data["priorities"].to(self.buffer_device)
+            self.idx = n_transitions
+            self._full = False
+
+    def sample(self, bs=None):
+        probs = (self._priorities if self._full else self._priorities[: self._sampling_idx]) ** self.cfg.per_alpha
+        probs /= probs.sum()
+        total = len(probs)
+        if torch.isnan(self._priorities).any():
+            print(torch.isnan(self._priorities).any())
+            print(torch.where(torch.isnan(self._priorities)))
+
idxs = ( + torch.from_numpy( + np.random.choice( + total, + self.cfg.batch_size if bs is None else bs, + p=probs.cpu().numpy(), + replace=((not self._full) or (self.cfg.batch_size > self.capacity)), + ) + ).to(self.buffer_device) + % self.capacity + ) + weights = (total * probs[idxs]) ** (-self.cfg.per_beta) + weights /= weights.max() + + idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)]) % self.capacity + + obs = self._aug(self._get_obs(self._obs, idxs, bs=bs)) + next_obs = [self._aug(self._get_obs(self._next_obs, _idxs, bs=bs)) for _idxs in idxs_in_horizon] + if isinstance(next_obs[0], dict): + next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]} + else: + next_obs = torch.stack(next_obs) + action = self._action[idxs_in_horizon] + reward = self._reward[idxs_in_horizon] + mask = self._mask[idxs_in_horizon] + done = torch.logical_not(self._done[idxs_in_horizon]).float() + + if not action.is_cuda: + action, reward, done, idxs, weights = ( + action.cuda(self.device), + reward.cuda(self.device), + done.cuda(self.device), + idxs.cuda(self.device), + weights.cuda(self.device), + ) + return ( + obs, + next_obs, + action, + reward.unsqueeze(2), + done.unsqueeze(2), + idxs, + weights, + ) diff --git a/sim/algo/tdmpc/src/algorithm/tdmpc.py b/sim/algo/tdmpc/src/algorithm/tdmpc.py new file mode 100644 index 00000000..09028197 --- /dev/null +++ b/sim/algo/tdmpc/src/algorithm/tdmpc.py @@ -0,0 +1,280 @@ +import numpy as np +import torch +import torch.nn as nn +from copy import deepcopy +from dataclasses import asdict +import sim.algo.tdmpc.src.algorithm.helper as h + + +class TOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self._encoder = h.enc(cfg) + self._dynamics = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, cfg.latent_dim) + self._reward = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, 1) + self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim) + self._Qs = nn.ModuleList([h.q(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim) for _ in range(cfg.num_q)]) + self.apply(h.orthogonal_init) + for m in [self._reward, *self._Qs]: + m[-1].weight.data.fill_(0) + m[-1].bias.data.fill_(0) + + def track_q_grad(self, enable=True): + """Utility function. 
Enables/disables gradient tracking of Q-networks."""
+        for m in self._Qs:
+            h.set_requires_grad(m, enable)
+
+    def h(self, obs):
+        """Encodes an observation into its latent representation (h)."""
+        return self._encoder(obs)
+
+    def next(self, z, a):
+        """Predicts next latent state (d) and single-step reward (R)."""
+        x = torch.cat([z, a], dim=-1)
+        return self._dynamics(x), self._reward(x)
+
+    def pi(self, z, std=0):
+        """Samples an action from the learned policy (pi)."""
+        mu = self.cfg.max_clip_actions * torch.tanh(self._pi(z))
+        if std > 0:
+            std = torch.ones_like(mu) * std
+            return h.TruncatedNormal(mu, std, low=-self.cfg.max_clip_actions, high=self.cfg.max_clip_actions).sample(
+                clip=0.3
+            )
+        return mu
+
+    def Q(self, z, a):
+        """Predicts the state-action value (Q)."""
+        x = torch.cat([z, a], dim=-1)
+        Qs = torch.stack([self._Qs[i](x) for i in range(self.cfg.num_q)], dim=0)
+        return Qs
+
+
+class TDMPC:
+    """Implementation of TD-MPC learning + inference."""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.device = torch.device(cfg.device)
+        self.std = h.linear_schedule(cfg.std_schedule, 0)
+        self.model = TOLD(cfg).to(self.device)
+        self.model_target = deepcopy(self.model)
+        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
+        self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=3 * self.cfg.lr)
+        self.aug = nn.Identity()
+        self.model.eval()
+        self.model_target.eval()
+
+    def state_dict(self, step=0, env_step=0):
+        """Retrieve state dict of TOLD model, including slow-moving target network."""
+        return {
+            "model": self.model.state_dict(),
+            "model_target": self.model_target.state_dict(),
+            "config": asdict(self.cfg),
+            "step": step,
+            "env_step": env_step,
+        }
+
+    def save(self, fp):
+        """Save state dict of TOLD model to filepath."""
+        torch.save(self.state_dict(), fp)
+
+    def load(self, fp):
+        """Load a saved state dict from filepath into current agent."""
+        d = torch.load(fp)
+        self.model.load_state_dict(d["model"])
+        self.model_target.load_state_dict(d["model_target"])
+        return d["step"], d["env_step"]
+
+    @torch.no_grad()
+    def estimate_value(self, z, actions, horizon):
+        """Estimate value of a trajectory starting at latent state z and executing given actions."""
+        G, discount = 0, 1
+        for t in range(horizon):
+            z, reward = self.model.next(z, actions[t])
+            G += discount * reward
+            discount *= self.cfg.discount
+        G += discount * torch.min(self.model.Q(z, self.model.pi(z, self.cfg.min_std)), dim=0)[0]
+        return G
+
+    @torch.no_grad()
+    def plan(self, obs, eval_mode=False, step=None, t0=True):
+        """
+        Plan next action using TD-MPC inference.
+        obs: raw input observation.
+        eval_mode: uniform sampling and action noise is disabled during evaluation.
+        step: current time step. Determines e.g. planning horizon.
+        t0: whether current step is the first step of an episode.
+ """ + clip_actions = h.linear_schedule(self.cfg.clip_actions, step=step) + # Seed steps + if step < self.cfg.seed_steps and not eval_mode: + return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_( + -clip_actions, clip_actions + ) + + # Sample policy trajectories + obs = obs.clone().to(self.device, dtype=torch.float32).unsqueeze(1) + horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) + num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) + if num_pi_trajs > 0: + pi_actions = torch.empty(horizon, self.cfg.num_envs, num_pi_trajs, self.cfg.action_dim, device=self.device) + z = self.model.h(obs).repeat(1, num_pi_trajs, 1) + for t in range(horizon): + pi_actions[t] = self.model.pi(z, self.cfg.min_std) + z, _ = self.model.next(z, pi_actions[t]) + + # Initialize state and parameters + z = self.model.h(obs).repeat(1, self.cfg.num_samples + num_pi_trajs, 1) + mean = torch.zeros(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) + std = 1.5 * torch.ones(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) * clip_actions + + if isinstance(t0, bool) and t0 and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + elif torch.is_tensor(t0) and t0.any() and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + + # Iterate CEM + for i in range(self.cfg.iterations): + actions = torch.clamp( + mean.unsqueeze(2) + + std.unsqueeze(2) + * torch.randn(horizon, self.cfg.num_envs, self.cfg.num_samples, self.cfg.action_dim, device=std.device), + -clip_actions, + clip_actions, + ) + if num_pi_trajs > 0: + actions = torch.cat([actions, pi_actions], dim=-2) + + # Compute elite actions + value = self.estimate_value(z, actions, horizon).nan_to_num_(0) + elite_idxs = torch.topk(value.squeeze(-1), self.cfg.num_elites, dim=-1).indices + elite_value, elite_actions = value.squeeze(-1).gather(-1, elite_idxs), actions.gather( + -2, elite_idxs.unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim) + ) + + # Update parameters + max_value = elite_value.max(1, keepdim=True)[0] + score = torch.exp(self.cfg.temperature * (elite_value - max_value)) + score /= score.sum(1, keepdim=True) + _mean = torch.sum(score.unsqueeze(0).unsqueeze(-1) * elite_actions, dim=-2) / ( + score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9 + ) + _std = torch.sqrt( + torch.sum(score.unsqueeze(0).unsqueeze(-1) * (elite_actions - _mean.unsqueeze(2)) ** 2, dim=-2) + / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9) + ) + _std = _std.clamp_(self.std, 1.5 * clip_actions) + mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std + + # Outputs + select_indices = torch.multinomial(score, 1) + actions = elite_actions.gather( + -2, select_indices.unsqueeze(0).unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim) + ).squeeze(-2) + self._prev_mean = mean + mean, std = actions[0], _std[0] + a = mean + if not eval_mode: + a = h.TruncatedNormal(a, std, low=-clip_actions, high=clip_actions).sample() + return a + + def update_pi(self, zs): + """Update policy using a sequence of latent states.""" + self.pi_optim.zero_grad(set_to_none=True) + self.model.track_q_grad(False) + + # Loss is a weighted sum of Q-values + pi_loss = 0 + for t, z in enumerate(zs): + a = self.model.pi(z, self.cfg.min_std) + Q = torch.min(self.model.Q(z, a), dim=0)[0] + pi_loss += 
-Q.mean() * (self.cfg.rho**t)
+
+        pi_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
+        self.pi_optim.step()
+        self.model.track_q_grad(True)
+        return pi_loss.item()
+
+    @torch.no_grad()
+    def _td_target(self, next_obs, reward, mask=1.0):
+        """Compute the TD-target from a reward and the observation at the following time step."""
+        next_z = self.model.h(next_obs)
+        td_target = (
+            reward
+            + self.cfg.discount
+            * mask
+            * torch.min(self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)), dim=0)[0]
+        )
+        return td_target
+
+    def update(self, replay_buffer, step):
+        """Main update function. Corresponds to one iteration of the TOLD model learning."""
+        obs, next_obses, action, reward, mask, idxs, weights = replay_buffer.sample()
+        self.optim.zero_grad(set_to_none=True)
+        self.std = h.linear_schedule(self.cfg.std_schedule, step)
+        self.model.train()
+
+        # Representation
+        z = self.model.h(self.aug(obs))
+        zs = [z.detach()]
+
+        loss_mask = torch.ones_like(mask[0], device=self.device)
+
+        consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0
+        for t in range(self.cfg.horizon):
+            if t > 0:
+                loss_mask = loss_mask * mask[t - 1]
+            # Predictions
+            Qs = self.model.Q(z, action[t])
+            z, reward_pred = self.model.next(z, action[t])
+            with torch.no_grad():
+                next_obs = self.aug(next_obses[t])
+                next_z = self.model_target.h(next_obs)
+                # Arguments follow the _td_target signature: (next_obs, reward, mask).
+                td_target = self._td_target(next_obs, reward[t], mask[t])
+            zs.append(z.detach())
+
+            # Losses
+            rho = self.cfg.rho**t
+            consistency_loss += loss_mask[t] * rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True)
+            reward_loss += loss_mask[t] * rho * h.mse(reward_pred, reward[t])
+            for i in range(self.cfg.num_q):
+                value_loss += loss_mask[t] * rho * h.mse(Qs[i], td_target)
+                priority_loss += loss_mask[t] * rho * h.l1(Qs[i], td_target)
+
+        # Optimize model
+        total_loss = (
+            self.cfg.consistency_coef * consistency_loss.clamp(max=1e4)
+            + self.cfg.reward_coef * reward_loss.clamp(max=1e4)
+            + self.cfg.value_coef * value_loss.clamp(max=1e4)
+        )
+        weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
+        weighted_loss.backward()
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
+        )
+        self.optim.step()
+        replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach())
+
+        if step % self.cfg.update_freq == 0:
+            # Update policy + target network
+            pi_loss = self.update_pi(zs)
+            h.ema(self.model, self.model_target, self.cfg.tau)
+
+        self.model.eval()
+        return {
+            "consistency_loss": float(consistency_loss.mean().item()),
+            "reward_loss": float(reward_loss.mean().item()),
+            "value_loss": float(value_loss.mean().item()),
+            "pi_loss": pi_loss if step % self.cfg.update_freq == 0 else 0.0,
+            "total_loss": float(total_loss.mean().item()),
+            "weighted_loss": float(weighted_loss.mean().item()),
+            "grad_norm": float(grad_norm),
+        }
diff --git a/sim/algo/tdmpc/src/logger.py b/sim/algo/tdmpc/src/logger.py
new file mode 100644
index 00000000..cdc6f9b7
--- /dev/null
+++ b/sim/algo/tdmpc/src/logger.py
@@ -0,0 +1,189 @@
+import sys
+import os
+import datetime
+import re
+import cv2
+import numpy as np
+import torch
+import pandas as pd
+from isaacgym import gymapi
+from termcolor import colored
+
+
+CONSOLE_FORMAT = [
+    ("episode", "E", "int"),
+    ("step", "S", "int"),
+    ("env_step", "ES", "int"),
+    ("episode_reward", "R", "float"),
+
("mean_episode_length", "MEL", "float"), + ("total_time", "T", "time"), +] +# ('consistency_loss', 'CL', 'float'), ('value_loss', 'PL', 'float'), ('pi_loss', 'PL', 'float'), ('total_loss', 'L', 'float')] +AGENT_METRICS = ["consistency_loss", "reward_loss", "value_loss", "total_loss", "weighted_loss", "pi_loss", "grad_norm"] + + +def make_dir(dir_path): + """Create directory if it does not already exist.""" + try: + os.makedirs(dir_path) + except OSError: + pass + return dir_path + + +def print_run(cfg, reward=None): + """Pretty-printing of run information. Call at start of training.""" + prefix, color, attrs = " ", "green", ["bold"] + + def limstr(s, maxlen=32): + return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s + + def pprint(k, v): + print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v)) + + kvs = [ + ("task", cfg.task), + ("train steps", f"{int(cfg.train_steps*cfg.action_repeat):,}"), + ("observations", "x".join([str(s) for s in cfg.obs_shape])), + ("actions", cfg.action_dim), + ("experiment", cfg.exp_name), + ] + if reward is not None: + kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))) + w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 + div = "-" * w + print(div) + for k, v in kvs: + pprint(k, v) + print(div) + + +def cfg_to_group(cfg, return_list=False): + """Return a wandb-safe group name for logging. Optionally returns group name as list.""" + lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] + return lst if return_list else "-".join(lst) + + +class VideoRecorder: + """Utility class for logging evaluation videos.""" + + def __init__(self, root_dir, wandb, render_size=384, fps=15): + self.save_dir = (root_dir / "eval_video") if root_dir else None + self._wandb = wandb + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + + def init(self, env, h1, enabled=True): + self.frames = [] + self.enabled = self.save_dir and self._wandb and enabled + self.record(env, h1) + + def record(self, env, h1): + if self.enabled: + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (120, 160, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + self.frames.append(img) + + def save(self, step): + if self.enabled: + frames = np.stack(self.frames).transpose(0, 3, 1, 2) + self._wandb.log({"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")}, step=step) + + +class Logger(object): + """Primary logger object. 
Logs either locally or using wandb.""" + + def __init__(self, log_dir, cfg): + project, entity = cfg.wandb_project, cfg.wandb_entity + run_offline = not cfg.use_wandb or project == "none" or entity == "none" + + self._save_model = cfg.save_model + if not run_offline or self._save_model: + self._log_dir = make_dir(log_dir) + + if self._save_model: + self._model_dir = make_dir(self._log_dir / "models") + + self._group = cfg_to_group(cfg) + self._seed = cfg.seed + self._cfg = cfg + self._eval = [] + print_run(cfg) + if run_offline: + print(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + else: + try: + os.environ["WANDB_SILENT"] = "true" + import wandb + + wandb.init( + project=project, + entity=entity, + name=str(cfg.seed), + group=self._group, + tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], + dir=self._log_dir, + ) + print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + self._wandb = wandb + except: + print(colored("Warning: failed to init wandb. Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + self._video = VideoRecorder(log_dir, self._wandb) if (self._wandb and cfg.save_video) else None + + @property + def video(self): + return self._video + + def save(self, agent, model_name="model.pt", step=None, env_step=None): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(step=step, env_step=env_step), fp) + + def finish(self, agent, model_name="model.pt"): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(), fp) + # if self._wandb: + # artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model') + # artifact.add_file(fp) + # self._wandb.log_artifact(artifact) + if self._wandb: + self._wandb.finish() + print_run(self._cfg, self._eval[-1][-1]) + + def _format(self, key, value, ty): + if ty == "int": + return f'{colored(key+":", "grey")} {int(value):,}' + elif ty == "float": + return f'{colored(key+":", "grey")} {value:.03f}' + elif ty == "time": + value = str(datetime.timedelta(seconds=int(value))) + return f'{colored(key+":", "grey")} {value}' + else: + raise f"invalid log format type: {ty}" + + def _print(self, d, category): + category = colored(category, "blue" if category == "train" else "green") + pieces = [f" {category:<14}"] + for k, disp_k, ty in CONSOLE_FORMAT: + pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}") + print(" ".join(pieces)) + + def log(self, d, category="train"): + assert category in {"train", "eval"} + if self._wandb is not None: + for k, v in d.items(): + self._wandb.log({category + "/" + k: v}, step=d["env_step"]) + if category == "eval": + keys = ["env_step", "episode_reward"] + self._eval.append(np.array([d[keys[0]], d[keys[1]]])) + pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None) + self._print(d, category) diff --git a/sim/play_tdmpc.py b/sim/play_tdmpc.py new file mode 100644 index 00000000..29670401 --- /dev/null +++ b/sim/play_tdmpc.py @@ -0,0 +1,220 @@ +"""Trains a humanoid to stand up.""" + +import argparse +import isaacgym +import os +import cv2 +import torch +from sim.envs import task_registry +from sim.utils.helpers import get_args +from sim.algo.tdmpc.src import logger +from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer +from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC +from dataclasses import dataclass, field +from isaacgym import gymapi +from typing import List 
+import time +import numpy as np +from pathlib import Path +import random +from datetime import datetime +from tqdm import tqdm + +torch.backends.cudnn.benchmark = True +__LOGS__ = "logs" + + +@dataclass +class TDMPC_DoraConfigs: + seed: int = 42 + task: str = "walk" + exp_name: str = "dora" + device: str = "cuda:0" + num_envs: int = 10 + episode_length: int = 100 + max_episode_length: int = 1000 + max_clip_actions: float = 1.0 + clip_actions: str = f"linear(1, {max_clip_actions}, 100000)" + + lr: float = 1e-3 + modality: str = "state" + enc_dim: int = 512 # 256 + mlp_dim = [512, 256] # [256, 256] + latent_dim: int = 100 + + iterations: int = 12 + num_samples: int = 512 + num_elites: int = 50 + mixture_coef: float = 0.05 + min_std: float = 0.05 + temperature: float = 0.5 + momentum: float = 0.1 + horizon: int = 5 + std_schedule: str = f"linear(0.5, {min_std}, 3000)" + horizon_schedule: str = f"linear(1, {horizon}, 2500)" + + batch_size: int = 1024 + max_buffer_size: int = 1000000 + reward_coef: float = 1 + value_coef: float = 0.5 + consistency_coef: float = 2 + rho: float = 0.5 + kappa: float = 0.1 + per_alpha: float = 0.6 + per_beta: float = 0.4 + grad_clip_norm: float = 10 + seed_steps: int = 750 + update_freq: int = 2 + tau: int = 0.01 + + discount: float = 0.99 + buffer_device: str = "cpu" + train_steps: int = int(1e6) + num_q: int = 3 + + action_repeat: int = 2 + eval_freq: int = 15000 + eval_freq_episode: int = 10 + eval_episodes: int = 1 + + save_model: bool = True + save_video: bool = False + eval_model: bool = False + save_buffer: bool = True + eval_freq_episode: int = 10 + eval_episodes: int = 1 + save_buffer_freq_episode: int = 50 + save_model_freq_episode: int = 10 + + use_wandb: bool = False + wandb_entity: str = "crajagopalan" + wandb_project: str = "xbot" + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +class VideoRecorder: + """Utility class for logging evaluation videos.""" + + def __init__(self, root_dir, render_size=384, fps=25): + self.save_dir = (root_dir / "eval_video") if root_dir else None + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + fourcc = cv2.VideoWriter_fourcc(*"MP4V") # type: ignore[attr-defined] + logger.make_dir(self.save_dir) + now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + # Creates a directory to store videos. 
+ dir = os.path.join(self.save_dir, now + ".mp4") + self.video = cv2.VideoWriter(dir, fourcc, float(fps), (1920, 1080)) + + def init(self, env, h1, enabled=True): + self.frames = [] + self.enabled = self.save_dir and enabled + self.record(env, h1) + + def record(self, env, h1): + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (1080, 1920, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + self.video.write(img) + + def save( + self, + ): + self.video.release() + + +def evaluate(test_env, agent, h1, step, video, action_repeat=1): + """Evaluate a trained agent and optionally save a video.""" + episode_rewards = [] + obs, privileged_obs = test_env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + dones, ep_reward, t = torch.tensor([False] * test_env.num_envs), torch.tensor([0.0] * test_env.num_envs), 0 + if video: + video.init(test_env, h1, enabled=True) + for i in tqdm(range(int(1000 // action_repeat))): + actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0) + for _ in range(action_repeat): + obs, privileged_obs, rewards, dones, infos = test_env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + ep_reward += rewards.cpu() + t += 1 + if video: + video.record(test_env, h1) + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + episode_rewards.append(ep_reward) + if video: + video.save() + print(f"Timestep : {t} Episode Rewards - {torch.cat(episode_rewards).mean().item()}") + return torch.nanmean(torch.cat(episode_rewards)).item() + + +def play(args: argparse.Namespace) -> None: + """Training script for TD-MPC. 
Requires a CUDA-enabled device."""
+    assert torch.cuda.is_available()
+    env_cfg, _ = task_registry.get_cfgs(name=args.task)
+    env, _ = task_registry.make_env(name=args.task, args=args)
+
+    fp = "/home/guest/sim/logs/2024-09-17_00-45-33_walk_state_dora/models/tdmpc_policy_350.pt"
+    config = torch.load(fp)["config"]
+    tdmpc_cfg = TDMPC_DoraConfigs(**config)
+    env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat)
+
+    camera_properties = gymapi.CameraProperties()
+    camera_properties.width = 1920
+    camera_properties.height = 1080
+    h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties)
+    camera_offset = gymapi.Vec3(3, -3, 1)
+    camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135))
+    actor_handle = env.gym.get_actor_handle(env.envs[0], 0)
+    body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0)
+    env.gym.attach_camera_to_body(
+        h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION
+    )
+
+    set_seed(tdmpc_cfg.seed)
+    work_dir = (
+        Path().cwd() / __LOGS__ / f"{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}_{str(tdmpc_cfg.seed)}"
+    )
+
+    obs, privileged_obs = env.reset()
+    critic_obs = privileged_obs if privileged_obs is not None else obs
+    state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+
+    tdmpc_cfg.obs_shape = [state.shape[0]]
+    tdmpc_cfg.action_shape = env.num_actions
+    tdmpc_cfg.action_dim = env.num_actions
+    tdmpc_cfg.episode_length = 100  # int(env.max_episode_length // tdmpc_cfg.action_repeat)
+    tdmpc_cfg.num_envs = env.num_envs
+
+    L = logger.Logger(work_dir, tdmpc_cfg)
+    log_dir = logger.make_dir(work_dir)
+    video = VideoRecorder(log_dir)
+
+    agent = TDMPC(tdmpc_cfg)
+
+    agent.load(fp)
+    step = 0
+    episode_idx, start_time = 0, time.time()
+    if fp is not None:
+        episode_idx = int(fp.split(".")[0].split("_")[-1])
+        step = episode_idx * tdmpc_cfg.episode_length
+
+    # Log training episode
+    evaluate(env, agent, h1, step, video, tdmpc_cfg.action_repeat)
+    print("Testing completed successfully")
+
+
+if __name__ == "__main__":
+    play(get_args())
diff --git a/sim/resources/stompymini/__init__.py b/sim/resources/stompymini/__init__.py
new file mode 100755
index 00000000..e69de29b
diff --git a/sim/resources/stompymini/robot.xml b/sim/resources/stompymini/robot.xml
new file mode 100755
index 00000000..db5f3f38
--- /dev/null
+++ b/sim/resources/stompymini/robot.xml
@@ -0,0 +1,598 @@
diff --git a/sim/resources/stompymini/robot_fixed.xml b/sim/resources/stompymini/robot_fixed.xml
new file mode 100755
index 00000000..aa4981d9
--- /dev/null
+++ b/sim/resources/stompymini/robot_fixed.xml
@@ -0,0 +1,961 @@
\ No newline at end of file
diff --git a/sim/train_tdmpc.py b/sim/train_tdmpc.py
new file mode 100644
index 00000000..d8391882
--- /dev/null
+++ b/sim/train_tdmpc.py
@@ -0,0 +1,253 @@
+"""Trains a humanoid to stand up."""
+
+import argparse
+import isaacgym
+import torch
+from sim.envs import task_registry
+from sim.utils.helpers import get_args
+from sim.algo.tdmpc.src import logger
+from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer
+from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC
+from dataclasses import dataclass, field, asdict
+from datetime import datetime
+from isaacgym import gymapi
+from typing import List
+import time
+import numpy as np
+from pathlib import Path
+import random
+
+torch.backends.cudnn.benchmark = True
+__LOGS__ = "logs"
+
+
+@dataclass
+class TDMPC_DoraConfigs:
+    seed: int = 42
+    task: str = "walk"
+    exp_name: str = "dora"
+    device: str = "cuda:0"
+    num_envs: int = 10
+    max_clip_actions: float = 1.0
+    clip_actions: str = f"linear(1, {max_clip_actions}, 100000)"
+    episode_length: int = 100
+    max_episode_length: int = 1000
+
+    lr: float = 1e-3
+    modality: str = "state"
+    enc_dim: int = 512
+    mlp_dim = [512, 256]
+    latent_dim: int = 100
+
+    iterations: int = 12
+    num_samples: int = 512
+    num_elites: int = 50
+    mixture_coef: float = 0.05
+    min_std: float = 0.1
+    temperature: float = 0.5
+    momentum: float = 0.1
+    horizon: int = 5
+    std_schedule: str = f"linear(0.5, {min_std}, 200000)"
+    horizon_schedule: str = f"linear(1, {horizon}, 15000)"
+
+    batch_size: int = 8192
+    max_buffer_size: int = int(5e6)
+    reward_coef: float = 1
+    value_coef: float = 0.75
+    consistency_coef: float = 2
+    rho: float = 0.75
+    kappa: float = 0.1
+    per_alpha: float = 0.6
+    per_beta: float = 0.4
+    grad_clip_norm: float = 50
+    seed_steps: int = 500
+    update_freq: int = 3
+    tau: int = 0.05
+
+    discount: float = 0.99
+    buffer_device: str = "cpu"
+    train_steps: int = int(1e6)
+    num_q: int = 2
+
+    action_repeat: int = 2
+
+    save_model: bool = True
+    save_video: bool = False
+    save_buffer: bool = True
+    eval_model: bool = False
+    eval_freq_episode: int = 10
+    eval_episodes: int = 1
+    save_buffer_freq_episode: int = 50
+    save_model_freq_episode: int = 50
+
+    use_wandb: bool = False
+    wandb_entity: str = "crajagopalan"
+    wandb_project: str = "xbot"
+
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def evaluate(test_env, agent, h1, num_episodes, step, env_step, video, action_repeat=1):
+    """Evaluate a trained agent and optionally save a video."""
+    episode_rewards = []
+    for i in range(num_episodes):
+        obs, privileged_obs = test_env.reset()
+        critic_obs = privileged_obs if privileged_obs is not None else obs
+        state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+        dones, ep_reward, t = torch.tensor([False]), 0, 0
+        if video:
+            video.init(test_env, h1, enabled=(i == 0))
+        while not dones[0].item():
+            actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0)
+            for _ in range(action_repeat):
+                obs, privileged_obs, rewards, dones, infos = test_env.step(actions)
+                critic_obs = 
privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0] + ep_reward += rewards[0] + if video: + video.record(test_env, h1) + t += 1 + episode_rewards.append(ep_reward) + if video: + video.save(env_step) + return torch.nanmean(torch.tensor(episode_rewards)).item() + + +def train(args: argparse.Namespace) -> None: + """Training script for TD-MPC. Requires a CUDA-enabled device.""" + assert torch.cuda.is_available() + env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) + env, _ = task_registry.make_env(name=args.task, args=args) + + tdmpc_cfg = TDMPC_DoraConfigs() + + if tdmpc_cfg.save_video: + env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) + + camera_properties = gymapi.CameraProperties() + camera_properties.width = 160 + camera_properties.height = 120 + h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties) + camera_offset = gymapi.Vec3(3, -3, 1) + camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135)) + actor_handle = env.gym.get_actor_handle(env.envs[0], 0) + body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0) + env.gym.attach_camera_to_body( + h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION + ) + + set_seed(tdmpc_cfg.seed) + now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + work_dir = Path().cwd() / __LOGS__ / f"{now}_{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}" + + obs, privileged_obs = env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + + tdmpc_cfg.obs_shape = state.shape[1:] + tdmpc_cfg.action_shape = env.num_actions + tdmpc_cfg.action_dim = env.num_actions + episode_length = 100 + tdmpc_cfg.episode_length = episode_length + tdmpc_cfg.num_envs = env.num_envs + tdmpc_cfg.max_episode_length = int(env.max_episode_length) + tdmpc_cfg.max_clip_actions = env.cfg.normalization.clip_actions + tdmpc_cfg.clip_actions = f"linear(1, {env.cfg.normalization.clip_actions}, 50000)" + print(tdmpc_cfg.clip_actions) + + agent = TDMPC(tdmpc_cfg) + buffer = ReplayBuffer(tdmpc_cfg) + fp = None + + init_step, env_step = 0, 0 + episode_idx, start_time = 0, time.time() + + if fp is not None: + init_step, env_step = agent.load(fp) + if init_step is None: + episode_idx = int(fp.split(".")[0].split("_")[-1]) + init_step = episode_idx * tdmpc_cfg.episode_length + env_step = init_step * tdmpc_cfg.episode_length + episode = Episode(tdmpc_cfg, state) + + L = logger.Logger(work_dir, tdmpc_cfg) + + for step in range(init_step, tdmpc_cfg.train_steps + tdmpc_cfg.episode_length, tdmpc_cfg.episode_length): + if episode.full: + episode = Episode(tdmpc_cfg, state) + for i in range(tdmpc_cfg.episode_length): + actions = agent.plan(state, t0=episode.first, eval_mode=False, step=step) + original_state = state.clone() + total_rewards, total_dones, total_timeouts = [], [], [] + for _ in range(tdmpc_cfg.action_repeat): + obs, privileged_obs, rewards, dones, infos = env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + total_rewards.append(rewards) + total_dones.append(dones) + total_timeouts.append(infos["time_outs"]) + episode += ( + original_state, + actions, + torch.stack(total_rewards).sum(dim=0), + torch.stack(total_dones).any(dim=0), + 
torch.stack(total_timeouts).any(dim=0), + torch.stack(total_dones).any(dim=0), + ) + buffer += episode + + # Update model + train_metrics = {} + if step >= tdmpc_cfg.seed_steps: + num_updates = tdmpc_cfg.seed_steps if step == tdmpc_cfg.seed_steps else tdmpc_cfg.episode_length + for i in range(num_updates): + train_metrics.update(agent.update(buffer, step + i)) + + # Log training episode + episode_idx += 1 + env_step += int(tdmpc_cfg.episode_length * tdmpc_cfg.num_envs) + common_metrics = { + "episode": episode_idx, + "step": step + tdmpc_cfg.episode_length, + "env_step": env_step, + "total_time": time.time() - start_time, + "episode_reward": episode.cumulative_reward.sum().item() / env.num_envs, + "mean_episode_length": episode.episode_length, + } + train_metrics.update(common_metrics) + L.log(train_metrics, category="train") + + # Evaluate agent periodically + if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.save_model_freq_episode == 0: + L.save( + agent, + f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt", + step + tdmpc_cfg.episode_length, + env_step, + ) + if tdmpc_cfg.save_buffer and episode_idx % tdmpc_cfg.save_buffer_freq_episode == 0: + buffer.save(str(work_dir / "buffer.pt")) + if tdmpc_cfg.eval_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0: + common_metrics["episode_reward"] = evaluate( + env, + agent, + h1 if L.video is not None else None, + tdmpc_cfg.eval_episodes, + step, + env_step, + L.video, + tdmpc_cfg.action_repeat, + ) + L.log(common_metrics, category="eval") + + print("Training completed successfully") + + +if __name__ == "__main__": + # python -m sim.humanoid_gym.train + train(get_args())