diff --git a/sim/algo/tdmpc/src/algorithm/__init__.py b/sim/algo/tdmpc/src/algorithm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sim/algo/tdmpc/src/algorithm/helper.py b/sim/algo/tdmpc/src/algorithm/helper.py
new file mode 100644
index 00000000..28d8093f
--- /dev/null
+++ b/sim/algo/tdmpc/src/algorithm/helper.py
@@ -0,0 +1,764 @@
+import glob
+import os
+import pickle
+import re
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from numpy.linalg import norm
+from scipy.spatial import distance
+from torch import distributions as pyd
+from torch.distributions.utils import _standard_normal
+
+__REDUCE__ = lambda b: "mean" if b else "none"
+
+
+def l1(pred, target, reduce=False):
+ """Computes the L1-loss between predictions and targets."""
+ return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))
+
+
+def mse(pred, target, reduce=False):
+ """Computes the MSE loss between predictions and targets."""
+ return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))
+
+
+def bce(pred, target, logits=True, reduce=False):
+ """Computes the BCE loss between predictions and targets."""
+ if logits:
+ return F.binary_cross_entropy_with_logits(pred, target, reduction=__REDUCE__(reduce))
+ return F.binary_cross_entropy(pred, target, reduction=__REDUCE__(reduce))
+
+
+def l1_quantile(y_true, y_pred, quantile=0.3):
+ """
+ Compute the quantile loss.
+
+ Args:
+ y_true (torch.Tensor): Ground truth values
+ y_pred (torch.Tensor): Predicted values
+ quantile (float): Quantile to compute, must be between 0 and 1
+
+ Returns:
+ torch.Tensor: Quantile loss
+ """
+ errors = y_true - y_pred
+ loss = torch.max((quantile - 1) * errors, quantile * errors)
+ return torch.mean(loss)
+
+
+def threshold_l2_expectile(diff, threshold=1e-2, expectile=0.99, reduce=False):
+ weight = torch.where(torch.abs(diff) > threshold, expectile, (1 - expectile))
+ loss = weight * (diff**2)
+ reduction = __REDUCE__(reduce)
+ if reduction == "mean":
+ return torch.mean(loss)
+ elif reduction == "sum":
+ return torch.sum(loss)
+ return loss
+
+
+def l2_expectile(diff, expectile=0.7, reduce=False):
+ weight = torch.where(diff > 0, expectile, (1 - expectile))
+ loss = weight * (diff**2)
+ reduction = __REDUCE__(reduce)
+ if reduction == "mean":
+ return torch.mean(loss)
+ elif reduction == "sum":
+ return torch.sum(loss)
+ return loss
+
+
+def mse_expectile(pred, target, expectile=0.7, reduce=False):
+ diff = pred - target
+ weight = torch.where(diff > 0, expectile, (1 - expectile))
+ loss = weight * (diff**2)
+ reduction = __REDUCE__(reduce)
+ if reduction == "mean":
+ return torch.mean(loss)
+ elif reduction == "sum":
+ return torch.sum(loss)
+ return loss
+
+
+def _get_out_shape(in_shape, layers):
+ """Utility function. Returns the output shape of a network for a given input shape."""
+ x = torch.randn(*in_shape).unsqueeze(0)
+ return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape
+
+
+def gaussian_logprob(eps, log_std):
+ """Compute Gaussian log probability."""
+ residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True)
+ return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1)
+
+
+def squash(mu, pi, log_pi):
+ """Apply squashing function."""
+ mu = torch.tanh(mu)
+ pi = torch.tanh(pi)
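+    # Tanh change-of-variables correction: log pi(a) = log pi(u) - sum_i log(1 - tanh(u_i)^2),
+    # with a small epsilon for numerical stability.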
+ log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
+ return mu, pi, log_pi
+
+
+def orthogonal_init(m):
+ """Orthogonal layer initialization."""
+ if isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight.data)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Conv2d):
+ gain = nn.init.calculate_gain("relu")
+ nn.init.orthogonal_(m.weight.data, gain)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+
+def ema(m, m_target, tau):
+ """Update slow-moving average of online network (target network) at rate tau."""
+ with torch.no_grad():
+ for p, p_target in zip(m.parameters(), m_target.parameters()):
+ p_target.data.lerp_(p.data, tau)
+
+
+def set_requires_grad(net, value):
+ """Enable/disable gradients for a given (sub)network."""
+ for param in net.parameters():
+ param.requires_grad_(value)
+
+
+def linear_schedule(schdl, step):
+ """
+ Outputs values following a linear decay schedule.
+ Adapted from https://github.com/facebookresearch/drqv2
+ """
+ try:
+ return float(schdl)
+ except ValueError:
+ match = re.match(r"linear\((.+),(.+),(.+)\)", schdl)
+ if match:
+ init, final, duration = [float(g) for g in match.groups()]
+ mix = np.clip(step / duration, 0.0, 1.0)
+ return (1.0 - mix) * init + mix * final
+ raise NotImplementedError(schdl)
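+# Example: the string "linear(0.5, 0.05, 3000)" decays linearly from 0.5 to 0.05 over the
+# first 3000 steps and then stays constant; plain numeric strings act as constants.
+#   linear_schedule("linear(0.5, 0.05, 3000)", 1500) -> 0.275
+#   linear_schedule("linear(0.5, 0.05, 3000)", 5000) -> 0.05
+#   linear_schedule("0.1", 42)                       -> 0.1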
+
+
+class TruncatedNormal(pyd.Normal):
+ """Utility class implementing the truncated normal distribution."""
+
+ def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
+ super().__init__(loc, scale, validate_args=False)
+ self.low = low
+ self.high = high
+ self.eps = eps
+
+ def _clamp(self, x):
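+        # Straight-through clamp: the forward value is clipped to (low + eps, high - eps),
+        # while gradients flow through the unclamped x.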
+ clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
+ x = x - x.detach() + clamped_x.detach()
+ return x
+
+ def sample(self, clip=None, sample_shape=torch.Size()):
+ shape = self._extended_shape(sample_shape)
+ eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+ eps *= self.scale
+ if clip is not None:
+ eps = torch.clamp(eps, -clip, clip)
+ x = self.loc + eps
+ return self._clamp(x)
+
+
+class NormalizeImg(nn.Module):
+ """Normalizes pixel observations to [0,1) range."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x.div(255.0)
+
+
+class Flatten(nn.Module):
+ """Flattens its input to a (batched) vector."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+def enc(cfg):
+ """Returns a TOLD encoder."""
+ pixels_enc_layers, state_enc_layers = None, None
+ if cfg.modality in {"pixels", "all"}:
+ C = int(3 * cfg.frame_stack)
+ pixels_enc_layers = [
+ NormalizeImg(),
+ nn.Conv2d(C, cfg.num_channels, 7, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2),
+ nn.ReLU(),
+ ]
+ out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers)
+ pixels_enc_layers.extend(
+ [
+ Flatten(),
+ nn.Linear(np.prod(out_shape), cfg.latent_dim),
+ nn.LayerNorm(cfg.latent_dim),
+ nn.Sigmoid(),
+ ]
+ )
+ if cfg.modality == "pixels":
+ return ConvExt(nn.Sequential(*pixels_enc_layers))
+ if cfg.modality in {"state", "all"}:
+ state_dim = cfg.obs_shape[0] if cfg.modality == "state" else cfg.obs_shape["state"][0]
+ state_enc_layers = [
+ nn.Linear(state_dim, cfg.enc_dim),
+ nn.LayerNorm(cfg.enc_dim),
+ nn.ELU(),
+ nn.Linear(cfg.enc_dim, cfg.enc_dim),
+ nn.LayerNorm(cfg.enc_dim),
+ nn.ELU(),
+ nn.Linear(cfg.enc_dim, cfg.latent_dim),
+ nn.LayerNorm(cfg.latent_dim),
+ nn.Tanh(),
+ ]
+ if cfg.modality == "state":
+ return nn.Sequential(*state_enc_layers)
+ else:
+ raise NotImplementedError
+
+ encoders = {}
+ for k in cfg.obs_shape:
+ if k == "state":
+ encoders[k] = nn.Sequential(*state_enc_layers)
+ elif k.endswith("rgb"):
+ encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers))
+ else:
+ raise NotImplementedError
+ return Multiplexer(nn.ModuleDict(encoders))
+
+
+def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.ReLU(), layer_norm=True):
+ """Returns an MLP."""
+ if isinstance(mlp_dim, int):
+ mlp_dim = [mlp_dim, mlp_dim]
+ layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn]
+ for i in range(len(mlp_dim) - 1):
+ layers += [
+ nn.Linear(mlp_dim[i], mlp_dim[i + 1]),
+ nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(),
+ act_fn,
+ ]
+ layers += [nn.Linear(mlp_dim[-1], out_dim)]
+ return nn.Sequential(*layers)
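+# Layout sketch (illustrative dims): mlp(48, [512, 256], 19) builds
+#   Linear(48, 512) -> LayerNorm -> ReLU -> Linear(512, 256) -> LayerNorm -> ReLU -> Linear(256, 19)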
+
+
+def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()):
+ """Returns a dynamics network."""
+ return nn.Sequential(
+ mlp(in_dim, mlp_dim, out_dim, act_fn),
+ nn.LayerNorm(out_dim),
+ act_fn,
+ )
+
+
+def q(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True):
+ """Returns a Q-function that uses Layer Normalization."""
+ if isinstance(mlp_dim, int):
+ mlp_dim = [mlp_dim, mlp_dim]
+ layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn]
+ for i in range(len(mlp_dim) - 1):
+ layers += [
+ nn.Linear(mlp_dim[i], mlp_dim[i + 1]),
+ nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(),
+ act_fn,
+ ]
+ layers += [nn.Linear(mlp_dim[-1], 1)]
+ return nn.Sequential(*layers)
+
+
+def v(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True):
+ """Returns a Q-function that uses Layer Normalization."""
+ if isinstance(mlp_dim, int):
+ mlp_dim = [mlp_dim, mlp_dim]
+ layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn]
+ for i in range(len(mlp_dim) - 1):
+ layers += [
+ nn.Linear(mlp_dim[i], mlp_dim[i + 1]),
+ nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(),
+ act_fn,
+ ]
+    layers += [nn.Linear(mlp_dim[-1], 1)]
+ return nn.Sequential(*layers)
+
+
+def aug(cfg):
+ if cfg.modality == "state":
+ return nn.Identity()
+ else:
+ augs = {}
+ for k in cfg.obs_shape:
+ if k == "state":
+ augs[k] = nn.Identity()
+ else:
+ raise NotImplementedError
+ return Multiplexer(nn.ModuleDict(augs))
+
+
+class ConvExt(nn.Module):
+ def __init__(self, conv):
+ super().__init__()
+ self.conv = conv
+
+ def forward(self, x):
+ if x.ndim > 4:
+ batch_shape = x.shape[:-3]
+ out = self.conv(x.view(-1, *x.shape[-3:]))
+ out = out.view(*batch_shape, *out.shape[1:])
+ else:
+ out = self.conv(x)
+ return out
+
+
+class Multiplexer(nn.Module):
+
+ def __init__(self, choices):
+ super().__init__()
+ self.choices = choices
+
+ def forward(self, x, key=None):
+ if isinstance(x, dict):
+ if key is not None:
+ return self.choices[key](x)
+ return {k: self.choices[k](_x) for k, _x in x.items()}
+ return self.choices(x)
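+# Usage sketch: Multiplexer(nn.ModuleDict({"state": state_enc, "rgb": pixel_enc})) applies the
+# matching sub-module to each entry of a dict observation and returns a dict of outputs;
+# non-dict inputs are forwarded to self.choices directly.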
+
+
+class Episode(object):
+ """Storage object for a single episode."""
+
+ def __init__(self, cfg, init_obs):
+ self.cfg = cfg
+ self.device = torch.device(cfg.buffer_device)
+ self.capacity = int(cfg.max_episode_length // cfg.action_repeat)
+ if cfg.modality in {"pixels", "state"}:
+ dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+ self.next_obses = torch.empty(
+ (cfg.num_envs, self.capacity, *init_obs.shape[1:]),
+ dtype=dtype,
+ device=self.device,
+ )
+ self.obses = torch.empty(
+ (cfg.num_envs, self.capacity, *init_obs.shape[1:]),
+ dtype=dtype,
+ device=self.device,
+ )
+ self.obses[:, 0] = init_obs.clone().to(self.device, dtype=dtype)
+ elif cfg.modality == "all":
+            self.obses, self.next_obses = {}, {}
+ for k, v in init_obs.items():
+ assert k in {"rgb", "state"}
+ dtype = torch.float32 if k == "state" else torch.uint8
+ self.next_obses[k] = torch.empty(
+ (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device
+ )
+ self.obses[k] = torch.empty(
+ (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device
+ )
+ self.obses[k][:, 0] = v.clone().to(self.device, dtype=dtype)
+ else:
+ raise ValueError
+ self.actions = torch.empty(
+ (cfg.num_envs, self.capacity, cfg.action_dim),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ self.rewards = torch.empty(
+ (
+ cfg.num_envs,
+ self.capacity,
+ ),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ self.dones = torch.empty(
+ (
+ cfg.num_envs,
+ self.capacity,
+ ),
+ dtype=torch.bool,
+ device=self.device,
+ )
+ self.successes = torch.empty(
+ (
+ cfg.num_envs,
+ self.capacity,
+ ),
+ dtype=torch.bool,
+ device=self.device,
+ )
+ self.masks = torch.zeros(
+ (
+ cfg.num_envs,
+ self.capacity,
+ ),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ self.cumulative_reward = torch.tensor([0.0] * cfg.num_envs)
+ self.done = torch.tensor([False] * cfg.num_envs)
+ self.success = torch.tensor([False] * cfg.num_envs)
+ self._idx = 0
+
+ def __len__(self):
+ return self._idx
+
+ @property
+ def episode_length(self):
+ num_dones = self.dones[:, : self._idx].sum().item()
+ if num_dones > 0:
+ return float(self._idx) * self.cfg.num_envs / num_dones
+ return float(self._idx)
+
+ @classmethod
+ def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None):
+ """Constructs an episode from a trajectory."""
+
+ if cfg.modality in {"pixels", "state"}:
+ episode = cls(cfg, obses[0])
+ episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device)
+ elif cfg.modality == "all":
+ episode = cls(cfg, {k: v[0] for k, v in obses.items()})
+ for k, v in obses.items():
+ episode.obses[k][1:] = torch.tensor(obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device)
+ else:
+ raise NotImplementedError
+ episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device)
+ episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device)
+ episode.dones = (
+ torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device)
+ if dones is not None
+ else torch.zeros_like(episode.dones)
+ )
+ episode.masks = (
+ torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device)
+ if masks is not None
+ else torch.ones_like(episode.masks)
+ )
+ episode.cumulative_reward = torch.sum(episode.rewards)
+ episode.done = True
+ episode._idx = cfg.episode_length
+ return episode
+
+ @property
+ def first(self):
+ return len(self) == 0
+
+ @property
+ def full(self):
+ return len(self) == self.capacity
+
+ @property
+ def buffer_capacity(self):
+ return self.capacity
+
+ def __add__(self, transition):
+ self.add(*transition)
+ return self
+
+ def add(self, obs, action, reward, done, timeouts, success=False):
+ if isinstance(obs, dict):
+ for k, v in obs.items():
+ if self._idx == self.capacity - 1:
+ self.next_obses[k][:, self._idx] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype)
+ elif self._idx < self.capacity - 1:
+ self.obses[k][:, self._idx + 1] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype)
+ self.next_obses[k][:, self._idx] = self.obses[k][:, self._idx + 1].clone()
+ else:
+ if self._idx == self.capacity - 1:
+ self.next_obses[:, self._idx] = obs.clone().to(self.obses.device, dtype=self.obses.dtype)
+ elif self._idx < self.capacity - 1:
+ self.obses[:, self._idx + 1] = obs.clone().to(self.obses.device, dtype=self.obses.dtype)
+ self.next_obses[:, self._idx] = self.obses[:, self._idx + 1].clone()
+ self.actions[:, self._idx] = action.detach().cpu()
+ self.rewards[:, self._idx] = reward.detach().cpu()
+ self.dones[:, self._idx] = done.detach().cpu()
+ self.masks[:, self._idx] = 1.0 - timeouts.detach().cpu().float() # TODO
+ self.cumulative_reward += reward.detach().cpu()
+ self.done = done.detach().cpu()
+ self.success = torch.logical_or(self.success, success.detach().cpu())
+        self.successes[:, self._idx] = self.success.to(self.device)
+ self._idx += 1
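+# Collection sketch (mirrors the rollout loop in sim/train_tdmpc.py):
+#   episode += (obs, action, reward, done, timeout, success)
+# forwards to Episode.add(); obses[t + 1] / next_obses[t] are filled from the incoming observation.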
+
+
+class ReplayBuffer:
+ """
+ Storage and sampling functionality for training TD-MPC / TOLD.
+ The replay buffer is stored in GPU memory when training from state.
+ Uses prioritized experience replay by default."""
+
+ def __init__(self, cfg, dataset=None):
+ self.cfg = cfg
+ self.buffer_device = torch.device(cfg.buffer_device)
+ self.device = torch.device(cfg.device)
+ self.batch_size = self.cfg.batch_size
+
+ print("Replay buffer device: ", self.buffer_device)
+ print("Replay buffer sample device: ", self.device)
+
+ if dataset is not None:
+ self.capacity = max(dataset["rewards"].shape[0], cfg.max_offline_buffer_size)
+ print("Offline dataset size: ", dataset["rewards"].shape[0])
+ else:
+ self.capacity = max(cfg.train_steps, cfg.max_buffer_size)
+
+ print("Maximum capacity of the buffer is: ", self.capacity)
+
+ if cfg.modality in {"pixels", "state"}:
+ dtype = torch.float32 if cfg.modality == "state" else torch.uint8
+            # Note: self.obs_shape always holds a single frame, unlike cfg.obs_shape (which may be frame-stacked)
+ self.obs_shape = cfg.obs_shape if cfg.modality == "state" else (3, *cfg.obs_shape[-2:])
+ self._obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device)
+ self._next_obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device)
+ elif cfg.modality == "all":
+ self.obs_shape = {}
+ self._obs, self._next_obs = {}, {}
+ for k, v in cfg.obs_shape.items():
+ assert k in {"rgb", "state"}
+ dtype = torch.float32 if k == "state" else torch.uint8
+ self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
+ self._obs[k] = torch.empty(
+ (self.capacity, *self.obs_shape[k]),
+ dtype=dtype,
+ device=self.buffer_device,
+ )
+ self._next_obs[k] = self._obs[k].clone()
+ else:
+ raise ValueError
+
+ self._action = torch.empty(
+ (self.capacity, cfg.action_dim),
+ dtype=torch.float32,
+ device=self.buffer_device,
+ )
+ self._reward = torch.empty((self.capacity,), dtype=torch.float32, device=self.buffer_device)
+ self._mask = torch.empty((self.capacity,), dtype=torch.float32, device=self.buffer_device)
+ self._done = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device)
+ self._success = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device)
+ self._priorities = torch.ones((self.capacity,), dtype=torch.float32, device=self.buffer_device)
+ self.ep_len = int(self.cfg.max_episode_length // self.cfg.action_repeat)
+ self._eps = 1e-6
+ self._full = False
+ self.idx = 0
+ self._sampling_idx = 0
+ if dataset is not None:
+ self.init_from_offline_dataset(dataset)
+
+ self._aug = aug(cfg)
+
+ def init_from_offline_dataset(self, dataset):
+ assert self.idx == 0 and not self._full
+ n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
+
+ def copy_data(dst, src, n):
+ assert isinstance(dst, dict) == isinstance(src, dict)
+ if isinstance(dst, dict):
+ for k in dst:
+ copy_data(dst[k], src[k], n)
+ else:
+ dst[:n] = torch.from_numpy(src[:n])
+
+ copy_data(self._obs, dataset["observations"], n_transitions)
+ copy_data(self._next_obs, dataset["next_observations"], n_transitions)
+ copy_data(self._action, dataset["actions"], n_transitions)
+ if self.cfg.task.startswith("xarm"):
+ # success = self._calc_sparse_success(dataset['success'])
+ # copy_data(self._reward, success.astype(np.float32), n_transitions)
+ if self.cfg.sparse_reward:
+ copy_data(self._reward, dataset["success"].astype(np.float32), n_transitions)
+ copy_data(self._success, dataset["success"], n_transitions)
+ else:
+ copy_data(self._reward, dataset["rewards"], n_transitions)
+ copy_data(self._mask, dataset["masks"], n_transitions)
+ copy_data(self._done, dataset["dones"], n_transitions)
+ self.idx = (self.idx + n_transitions) % self.capacity
+ self._full = n_transitions >= self.capacity
+ mask_idxs = np.array([n_transitions - i for i in range(1, self.cfg.horizon)])
+ # _, episode_ends, _ = get_trajectory_boundaries_and_returns(dataset)
+ # mask_idxs = np.array([np.array(episode_ends) - i for i in range(1, self.cfg.horizon)]).T.flatten()
+ mask_idxs = np.clip(mask_idxs, 0, n_transitions - 1)
+ self._priorities[mask_idxs] = 0
+
+ def __add__(self, episode: Episode):
+ self.add(episode)
+ return self
+
+ def add(self, episode: Episode):
+ idxs = torch.arange(self.idx, self.idx + self.cfg.num_envs * self.ep_len) % self.capacity
+ self._sampling_idx = (self.idx + self.cfg.num_envs * self.ep_len) % self.capacity
+ mask_copy = episode.masks.clone()
+ mask_copy[:, episode._idx - self.cfg.horizon :] = 0.0
+ if self.cfg.modality in {"pixels", "state"}:
+ self._obs[idxs] = (
+ episode.obses.flatten(0, 1) if self.cfg.modality == "state" else episode.obses[:, -3:].flatten(0, 1)
+ )
+ self._next_obs[idxs] = (
+ episode.next_obses.flatten(0, 1)
+ if self.cfg.modality == "state"
+ else episode.next_obses[:, -3:].flatten(0, 1)
+ )
+ elif self.cfg.modality == "all":
+ for k, v in episode.obses.items():
+ assert k in {"rgb", "state"}
+ assert k in self._obs
+ assert k in self._next_obs
+ if k == "rgb":
+ self._obs[k][idxs] = episode.obses[k][:-1, -3:].flatten(0, 1)
+ self._next_obs[k][idxs] = episode.obses[k][1:, -3:].flatten(0, 1)
+ else:
+ self._obs[k][idxs] = episode.obses[k][:-1].flatten(0, 1)
+ self._next_obs[k][idxs] = episode.obses[k][1:].flatten(0, 1)
+ self._action[idxs] = episode.actions.flatten(0, 1)
+ self._reward[idxs] = episode.rewards.flatten(0, 1)
+ self._mask[idxs] = mask_copy.flatten(0, 1)
+ self._done[idxs] = episode.dones.flatten(0, 1)
+ self._success[idxs] = episode.successes.flatten(0, 1)
+ if self._full:
+ max_priority = self._priorities.max().to(self.device).item()
+ else:
+ max_priority = 1.0 if self.idx == 0 else self._priorities[: self.idx].max().to(self.device).item()
+ mask = torch.arange(self.ep_len) > self.ep_len - self.cfg.horizon
+ mask = torch.cat([mask] * self.cfg.num_envs)
+ new_priorities = torch.full((self.ep_len * self.cfg.num_envs,), max_priority, device=self.buffer_device)
+ new_priorities[mask] = 0
+ new_priorities = new_priorities * self._mask[idxs]
+ self._priorities[idxs] = new_priorities
+ self._full = self._full or (self.idx + self.ep_len * self.cfg.num_envs > self.capacity)
+ if episode.full:
+ self.idx = (self.idx + self.ep_len * self.cfg.num_envs) % self.capacity
+
+ def _set_bs(self, bs):
+ self.batch_size = bs
+
+ def update_priorities(self, idxs, priorities):
+ self._priorities[idxs.to(self.buffer_device)] = priorities.squeeze(1).to(self.buffer_device) + self._eps
+
+ def _get_obs(self, arr, idxs, bs=None, frame_stack=None):
+ if isinstance(arr, dict):
+ return {k: self._get_obs(v, idxs, bs=bs, frame_stack=frame_stack) for k, v in arr.items()}
+ if arr.ndim <= 2: # if self.cfg.modality == 'state':
+ return arr[idxs].cuda(self.device)
+ obs = torch.empty(
+ (
+ self.cfg.batch_size if bs is None else bs,
+ 3 * self.cfg.frame_stack if frame_stack is None else 3 * frame_stack,
+ *arr.shape[-2:],
+ ),
+ dtype=arr.dtype,
+ device=torch.device(self.device),
+ )
+ obs[:, -3:] = arr[idxs].cuda(self.device)
+ _idxs = idxs.clone()
+ mask = torch.ones_like(_idxs, dtype=torch.bool)
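+        # Frame stacking: earlier frames are gathered by stepping indices backwards; indices that
+        # hit an episode boundary stop decrementing, so the episode's first frame is repeated.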
+ for i in range(1, self.cfg.frame_stack if frame_stack is None else frame_stack):
+ mask[_idxs % self.cfg.episode_length == 0] = False
+ _idxs[mask] -= 1
+ obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda(self.device)
+ return obs.float()
+
+ def save(self, buffer_fp):
+ data = {
+ "obs": self._obs.cpu(),
+ "next_obs": self._next_obs.cpu(),
+ "action": self._action.cpu(),
+ "reward": self._reward.cpu(),
+ "mask": self._mask.cpu(),
+ "done": self._done.cpu(),
+ "priorities": self._priorities.cpu(),
+ }
+ torch.save(data, buffer_fp)
+
+ def load(self, buffer_fp):
+ data = torch.load(buffer_fp)
+ n_transitions = data["obs"].shape[0]
+
+ if n_transitions >= self.capacity:
+            self._obs = data["obs"][-self.capacity :].to(self.buffer_device)
+ self._next_obs = data["next_obs"][-self.capacity :].to(self.buffer_device)
+ self._action = data["action"][-self.capacity :].to(self.buffer_device)
+ self._reward = data["reward"][-self.capacity :].to(self.buffer_device)
+ self._mask = data["mask"][-self.capacity :].to(self.buffer_device)
+ self._done = data["done"][-self.capacity :].to(self.buffer_device)
+ self._priorities = data["priorities"][-self.capacity :].to(self.buffer_device)
+ self.idx = 0
+ self._full = True
+ else:
+            self._obs[:n_transitions] = data["obs"].to(self.buffer_device)
+ self._next_obs[:n_transitions] = data["next_obs"].to(self.buffer_device)
+ self._action[:n_transitions] = data["action"].to(self.buffer_device)
+ self._reward[:n_transitions] = data["reward"].to(self.buffer_device)
+ self._mask[:n_transitions] = data["mask"].to(self.buffer_device)
+ self._done[:n_transitions] = data["done"].to(self.buffer_device)
+ self._priorities[:n_transitions] = data["priorities"].to(self.buffer_device)
+ self.idx = n_transitions
+ self._full = False
+
+ def sample(self, bs=None):
+ probs = (self._priorities if self._full else self._priorities[: self._sampling_idx]) ** self.cfg.per_alpha
+ probs /= probs.sum()
+ total = len(probs)
+ if torch.isnan(self._priorities).any():
+ print(torch.isnan(self._priorities).any())
+ print(torch.where(torch.isnan(self._priorities)))
+ idxs = (
+ torch.from_numpy(
+ np.random.choice(
+ total,
+ self.cfg.batch_size if bs is None else bs,
+ p=probs.cpu().numpy(),
+ replace=((not self._full) or (self.cfg.batch_size > self.capacity)),
+ )
+ ).to(self.buffer_device)
+ % self.capacity
+ )
+ weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
+ weights /= weights.max()
+
+ idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)]) % self.capacity
+
+ obs = self._aug(self._get_obs(self._obs, idxs, bs=bs))
+ next_obs = [self._aug(self._get_obs(self._next_obs, _idxs, bs=bs)) for _idxs in idxs_in_horizon]
+ if isinstance(next_obs[0], dict):
+ next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
+ else:
+ next_obs = torch.stack(next_obs)
+ action = self._action[idxs_in_horizon]
+ reward = self._reward[idxs_in_horizon]
+ mask = self._mask[idxs_in_horizon]
+ done = torch.logical_not(self._done[idxs_in_horizon]).float()
+
+ if not action.is_cuda:
+ action, reward, done, idxs, weights = (
+ action.cuda(self.device),
+ reward.cuda(self.device),
+ done.cuda(self.device),
+ idxs.cuda(self.device),
+ weights.cuda(self.device),
+ )
+ return (
+ obs,
+ next_obs,
+ action,
+ reward.unsqueeze(2),
+ done.unsqueeze(2),
+ idxs,
+ weights,
+ )
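+
+# Sampled batch layout (state modality; B = batch size, H = cfg.horizon):
+#   obs (B, obs_dim), next_obs (H, B, obs_dim), action (H, B, action_dim),
+#   reward (H, B, 1), done (H, B, 1), idxs (B,), weights (B,).
+# Note: the returned `done` is a continuation mask, i.e. logical_not of the stored dones.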
diff --git a/sim/algo/tdmpc/src/algorithm/tdmpc.py b/sim/algo/tdmpc/src/algorithm/tdmpc.py
new file mode 100644
index 00000000..09028197
--- /dev/null
+++ b/sim/algo/tdmpc/src/algorithm/tdmpc.py
@@ -0,0 +1,280 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from copy import deepcopy
+from dataclasses import asdict
+import sim.algo.tdmpc.src.algorithm.helper as h
+
+
+class TOLD(nn.Module):
+ """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
+
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self._encoder = h.enc(cfg)
+ self._dynamics = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, cfg.latent_dim)
+ self._reward = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, 1)
+ self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim)
+ self._Qs = nn.ModuleList([h.q(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim) for _ in range(cfg.num_q)])
+ self.apply(h.orthogonal_init)
+ for m in [self._reward, *self._Qs]:
+ m[-1].weight.data.fill_(0)
+ m[-1].bias.data.fill_(0)
+
+ def track_q_grad(self, enable=True):
+ """Utility function. Enables/disables gradient tracking of Q-networks."""
+ for m in self._Qs:
+ h.set_requires_grad(m, enable)
+
+ def h(self, obs):
+ """Encodes an observation into its latent representation (h)."""
+ return self._encoder(obs)
+
+ def next(self, z, a):
+ """Predicts next latent state (d) and single-step reward (R)."""
+ x = torch.cat([z, a], dim=-1)
+ return self._dynamics(x), self._reward(x)
+
+ def pi(self, z, std=0):
+ """Samples an action from the learned policy (pi)."""
+ mu = self.cfg.max_clip_actions * torch.tanh(self._pi(z))
+ if std > 0:
+ std = torch.ones_like(mu) * std
+ return h.TruncatedNormal(mu, std, low=-self.cfg.max_clip_actions, high=self.cfg.max_clip_actions).sample(
+ clip=0.3
+ )
+ return mu
+
+ def Q(self, z, a):
+ """Predict state-action value (Q)."""
+ x = torch.cat([z, a], dim=-1)
+ Qs = torch.stack([self._Qs[i](x) for i in range(self.cfg.num_q)], dim=0)
+ return Qs
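+    # Shape note: Q returns the stacked ensemble of shape (cfg.num_q, batch, 1); callers take
+    # torch.min(..., dim=0)[0] for a pessimistic value estimate.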
+
+
+class TDMPC:
+ """Implementation of TD-MPC learning + inference."""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.device = torch.device(cfg.device)
+ self.std = h.linear_schedule(cfg.std_schedule, 0)
+ self.model = TOLD(cfg).to(self.device)
+ self.model_target = deepcopy(self.model)
+ self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
+ self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=3 * self.cfg.lr)
+ self.aug = nn.Identity()
+ self.model.eval()
+ self.model_target.eval()
+
+ def state_dict(self, step=0, env_step=0):
+ """Retrieve state dict of TOLD model, including slow-moving target network."""
+ return {
+ "model": self.model.state_dict(),
+ "model_target": self.model_target.state_dict(),
+ "config": asdict(self.cfg),
+ "step": step,
+ "env_step": env_step,
+ }
+
+ def save(self, fp):
+ """Save state dict of TOLD model to filepath."""
+        torch.save(self.state_dict(), fp)
+
+ def load(self, fp):
+ """Load a saved state dict from filepath into current agent."""
+ d = torch.load(fp)
+ self.model.load_state_dict(d["model"])
+ self.model_target.load_state_dict(d["model_target"])
+ return d["step"], d["env_step"]
+
+ @torch.no_grad()
+ def estimate_value(self, z, actions, horizon):
+ """Estimate value of a trajectory starting at latent state z and executing given actions."""
+ G, discount = 0, 1
+ for t in range(horizon):
+ z, reward = self.model.next(z, actions[t])
+ G += discount * reward
+ discount *= self.cfg.discount
+ G += discount * torch.min(self.model.Q(z, self.model.pi(z, self.cfg.min_std)), dim=0)[0]
+ return G
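+    # Planner value estimate: G = sum_{t<H} gamma^t * r_hat_t + gamma^H * min_i Q_i(z_H, pi(z_H)).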
+
+ @torch.no_grad()
+ def plan(self, obs, eval_mode=False, step=None, t0=True):
+ """
+ Plan next action using TD-MPC inference.
+ obs: raw input observation.
+        eval_mode: uniform action sampling and exploration noise are disabled during evaluation.
+        step: current time step; determines e.g. the planning horizon.
+ t0: whether current step is the first step of an episode.
+ """
+ clip_actions = h.linear_schedule(self.cfg.clip_actions, step=step)
+ # Seed steps
+ if step < self.cfg.seed_steps and not eval_mode:
+ return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_(
+ -clip_actions, clip_actions
+ )
+
+ # Sample policy trajectories
+ obs = obs.clone().to(self.device, dtype=torch.float32).unsqueeze(1)
+ horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
+ num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
+ if num_pi_trajs > 0:
+ pi_actions = torch.empty(horizon, self.cfg.num_envs, num_pi_trajs, self.cfg.action_dim, device=self.device)
+ z = self.model.h(obs).repeat(1, num_pi_trajs, 1)
+ for t in range(horizon):
+ pi_actions[t] = self.model.pi(z, self.cfg.min_std)
+ z, _ = self.model.next(z, pi_actions[t])
+
+ # Initialize state and parameters
+ z = self.model.h(obs).repeat(1, self.cfg.num_samples + num_pi_trajs, 1)
+ mean = torch.zeros(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device)
+ std = 1.5 * torch.ones(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) * clip_actions
+
+ if isinstance(t0, bool) and t0 and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1:
+ _prev_h = self._prev_mean.shape[0] - 1
+ mean[:_prev_h] = self._prev_mean[1:]
+ elif torch.is_tensor(t0) and t0.any() and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1:
+ _prev_h = self._prev_mean.shape[0] - 1
+ mean[:_prev_h] = self._prev_mean[1:]
+
+ # Iterate CEM
+ for i in range(self.cfg.iterations):
+ actions = torch.clamp(
+ mean.unsqueeze(2)
+ + std.unsqueeze(2)
+ * torch.randn(horizon, self.cfg.num_envs, self.cfg.num_samples, self.cfg.action_dim, device=std.device),
+ -clip_actions,
+ clip_actions,
+ )
+ if num_pi_trajs > 0:
+ actions = torch.cat([actions, pi_actions], dim=-2)
+
+ # Compute elite actions
+ value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
+ elite_idxs = torch.topk(value.squeeze(-1), self.cfg.num_elites, dim=-1).indices
+ elite_value, elite_actions = value.squeeze(-1).gather(-1, elite_idxs), actions.gather(
+ -2, elite_idxs.unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)
+ )
+
+ # Update parameters
+ max_value = elite_value.max(1, keepdim=True)[0]
+ score = torch.exp(self.cfg.temperature * (elite_value - max_value))
+ score /= score.sum(1, keepdim=True)
+ _mean = torch.sum(score.unsqueeze(0).unsqueeze(-1) * elite_actions, dim=-2) / (
+ score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9
+ )
+ _std = torch.sqrt(
+ torch.sum(score.unsqueeze(0).unsqueeze(-1) * (elite_actions - _mean.unsqueeze(2)) ** 2, dim=-2)
+ / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9)
+ )
+ _std = _std.clamp_(self.std, 1.5 * clip_actions)
+ mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std
+
+ # Outputs
+ select_indices = torch.multinomial(score, 1)
+ actions = elite_actions.gather(
+ -2, select_indices.unsqueeze(0).unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)
+ ).squeeze(-2)
+ self._prev_mean = mean
+ mean, std = actions[0], _std[0]
+ a = mean
+ if not eval_mode:
+ a = h.TruncatedNormal(a, std, low=-clip_actions, high=clip_actions).sample()
+ return a
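+    # Usage sketch (as called from sim/train_tdmpc.py):
+    #   actions = agent.plan(state, t0=episode.first, eval_mode=False, step=step)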
+
+ def update_pi(self, zs):
+ """Update policy using a sequence of latent states."""
+ self.pi_optim.zero_grad(set_to_none=True)
+ self.model.track_q_grad(False)
+
+ # Loss is a weighted sum of Q-values
+ pi_loss = 0
+ for t, z in enumerate(zs):
+ a = self.model.pi(z, self.cfg.min_std)
+ Q = torch.min(self.model.Q(z, a), dim=0)[0]
+ pi_loss += -Q.mean() * (self.cfg.rho**t)
+
+ pi_loss.backward()
+ torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
+ self.pi_optim.step()
+ self.model.track_q_grad(True)
+ return pi_loss.item()
+
+ @torch.no_grad()
+ def _td_target(self, next_obs, reward, mask=1.0):
+ """Compute the TD-target from a reward and the observation at the following time step."""
+ next_z = self.model.h(next_obs)
+ td_target = (
+ reward
+ + self.cfg.discount
+ * mask
+ * torch.min(self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)), dim=0)[0]
+ )
+ return td_target
+
+ def update(self, replay_buffer, step):
+ """Main update function. Corresponds to one iteration of the TOLD model learning."""
+ obs, next_obses, action, reward, mask, idxs, weights = replay_buffer.sample()
+ self.optim.zero_grad(set_to_none=True)
+ self.std = h.linear_schedule(self.cfg.std_schedule, step)
+ self.model.train()
+
+ # Representation
+ z = self.model.h(self.aug(obs))
+ zs = [z.detach()]
+
+ loss_mask = torch.ones_like(mask[0], device=self.device)
+
+ consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0
+ for t in range(self.cfg.horizon):
+ if t > 0:
+ loss_mask = loss_mask * mask[t - 1]
+ # Predictions
+ Qs = self.model.Q(z, action[t])
+ z, reward_pred = self.model.next(z, action[t])
+ with torch.no_grad():
+ next_obs = self.aug(next_obses[t])
+ next_z = self.model_target.h(next_obs)
+                td_target = self._td_target(next_obs, reward[t], mask[t])
+ zs.append(z.detach())
+
+ # Losses
+ rho = self.cfg.rho**t
+ consistency_loss += loss_mask[t] * rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True)
+ reward_loss += loss_mask[t] * rho * h.mse(reward_pred, reward[t])
+ for i in range(self.cfg.num_q):
+ value_loss += loss_mask[t] * rho * h.mse(Qs[i], td_target)
+ priority_loss += loss_mask[t] * rho * h.l1(Qs[i], td_target)
+
+ # Optimize model
+ total_loss = (
+ self.cfg.consistency_coef * consistency_loss.clamp(max=1e4)
+ + self.cfg.reward_coef * reward_loss.clamp(max=1e4)
+ + self.cfg.value_coef * value_loss.clamp(max=1e4)
+ )
+ weighted_loss = (total_loss.squeeze(1) * weights).mean()
+ weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
+ weighted_loss.backward()
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
+ )
+ self.optim.step()
+ replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach())
+
+ if step % self.cfg.update_freq == 0:
+ # Update policy + target network
+ pi_loss = self.update_pi(zs)
+ h.ema(self.model, self.model_target, self.cfg.tau)
+
+ self.model.eval()
+ return {
+ "consistency_loss": float(consistency_loss.mean().item()),
+ "reward_loss": float(reward_loss.mean().item()),
+ "value_loss": float(value_loss.mean().item()),
+ "pi_loss": pi_loss if step % self.cfg.update_freq == 0 else 0.0,
+ "total_loss": float(total_loss.mean().item()),
+ "weighted_loss": float(weighted_loss.mean().item()),
+ "grad_norm": float(grad_norm),
+ }
diff --git a/sim/algo/tdmpc/src/logger.py b/sim/algo/tdmpc/src/logger.py
new file mode 100644
index 00000000..cdc6f9b7
--- /dev/null
+++ b/sim/algo/tdmpc/src/logger.py
@@ -0,0 +1,189 @@
+import sys
+import os
+import datetime
+import re
+import cv2
+import numpy as np
+import torch
+import pandas as pd
+from isaacgym import gymapi
+from termcolor import colored
+
+
+CONSOLE_FORMAT = [
+ ("episode", "E", "int"),
+ ("step", "S", "int"),
+ ("env_step", "ES", "int"),
+ ("episode_reward", "R", "float"),
+ ("mean_episode_length", "MEL", "float"),
+ ("total_time", "T", "time"),
+]
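+# Each entry above is (metric key, printed abbreviation, format type), consumed by Logger._print.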
+# ('consistency_loss', 'CL', 'float'), ('value_loss', 'PL', 'float'), ('pi_loss', 'PL', 'float'), ('total_loss', 'L', 'float')]
+AGENT_METRICS = ["consistency_loss", "reward_loss", "value_loss", "total_loss", "weighted_loss", "pi_loss", "grad_norm"]
+
+
+def make_dir(dir_path):
+ """Create directory if it does not already exist."""
+ try:
+ os.makedirs(dir_path)
+ except OSError:
+ pass
+ return dir_path
+
+
+def print_run(cfg, reward=None):
+ """Pretty-printing of run information. Call at start of training."""
+ prefix, color, attrs = " ", "green", ["bold"]
+
+ def limstr(s, maxlen=32):
+ return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
+
+ def pprint(k, v):
+ print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v))
+
+ kvs = [
+ ("task", cfg.task),
+ ("train steps", f"{int(cfg.train_steps*cfg.action_repeat):,}"),
+ ("observations", "x".join([str(s) for s in cfg.obs_shape])),
+ ("actions", cfg.action_dim),
+ ("experiment", cfg.exp_name),
+ ]
+ if reward is not None:
+ kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
+ w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
+ div = "-" * w
+ print(div)
+ for k, v in kvs:
+ pprint(k, v)
+ print(div)
+
+
+def cfg_to_group(cfg, return_list=False):
+ """Return a wandb-safe group name for logging. Optionally returns group name as list."""
+ lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
+ return lst if return_list else "-".join(lst)
+
+
+class VideoRecorder:
+ """Utility class for logging evaluation videos."""
+
+ def __init__(self, root_dir, wandb, render_size=384, fps=15):
+ self.save_dir = (root_dir / "eval_video") if root_dir else None
+ self._wandb = wandb
+ self.render_size = render_size
+ self.fps = fps
+ self.frames = []
+ self.enabled = False
+
+ def init(self, env, h1, enabled=True):
+ self.frames = []
+ self.enabled = self.save_dir and self._wandb and enabled
+ self.record(env, h1)
+
+ def record(self, env, h1):
+ if self.enabled:
+ env.gym.fetch_results(env.sim, True)
+ env.gym.step_graphics(env.sim)
+ env.gym.render_all_camera_sensors(env.sim)
+ img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR)
+ img = np.reshape(img, (120, 160, 4))
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
+ self.frames.append(img)
+
+ def save(self, step):
+ if self.enabled:
+ frames = np.stack(self.frames).transpose(0, 3, 1, 2)
+ self._wandb.log({"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")}, step=step)
+
+
+class Logger(object):
+ """Primary logger object. Logs either locally or using wandb."""
+
+ def __init__(self, log_dir, cfg):
+ project, entity = cfg.wandb_project, cfg.wandb_entity
+ run_offline = not cfg.use_wandb or project == "none" or entity == "none"
+
+ self._save_model = cfg.save_model
+ if not run_offline or self._save_model:
+ self._log_dir = make_dir(log_dir)
+
+ if self._save_model:
+ self._model_dir = make_dir(self._log_dir / "models")
+
+ self._group = cfg_to_group(cfg)
+ self._seed = cfg.seed
+ self._cfg = cfg
+ self._eval = []
+ print_run(cfg)
+ if run_offline:
+ print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+ self._wandb = None
+ else:
+ try:
+ os.environ["WANDB_SILENT"] = "true"
+ import wandb
+
+ wandb.init(
+ project=project,
+ entity=entity,
+ name=str(cfg.seed),
+ group=self._group,
+ tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
+ dir=self._log_dir,
+ )
+ print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+ self._wandb = wandb
+            except Exception:
+ print(colored("Warning: failed to init wandb. Logs will be saved locally.", "yellow", attrs=["bold"]))
+ self._wandb = None
+ self._video = VideoRecorder(log_dir, self._wandb) if (self._wandb and cfg.save_video) else None
+
+ @property
+ def video(self):
+ return self._video
+
+ def save(self, agent, model_name="model.pt", step=None, env_step=None):
+ if self._save_model:
+ fp = self._model_dir / f"{model_name}"
+ torch.save(agent.state_dict(step=step, env_step=env_step), fp)
+
+ def finish(self, agent, model_name="model.pt"):
+ if self._save_model:
+ fp = self._model_dir / f"{model_name}"
+ torch.save(agent.state_dict(), fp)
+ # if self._wandb:
+ # artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model')
+ # artifact.add_file(fp)
+ # self._wandb.log_artifact(artifact)
+ if self._wandb:
+ self._wandb.finish()
+ print_run(self._cfg, self._eval[-1][-1])
+
+ def _format(self, key, value, ty):
+ if ty == "int":
+ return f'{colored(key+":", "grey")} {int(value):,}'
+ elif ty == "float":
+ return f'{colored(key+":", "grey")} {value:.03f}'
+ elif ty == "time":
+ value = str(datetime.timedelta(seconds=int(value)))
+ return f'{colored(key+":", "grey")} {value}'
+ else:
+ raise f"invalid log format type: {ty}"
+
+ def _print(self, d, category):
+ category = colored(category, "blue" if category == "train" else "green")
+ pieces = [f" {category:<14}"]
+ for k, disp_k, ty in CONSOLE_FORMAT:
+ pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}")
+ print(" ".join(pieces))
+
+ def log(self, d, category="train"):
+ assert category in {"train", "eval"}
+ if self._wandb is not None:
+ for k, v in d.items():
+ self._wandb.log({category + "/" + k: v}, step=d["env_step"])
+ if category == "eval":
+ keys = ["env_step", "episode_reward"]
+ self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
+ pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None)
+ self._print(d, category)
diff --git a/sim/play_tdmpc.py b/sim/play_tdmpc.py
new file mode 100644
index 00000000..29670401
--- /dev/null
+++ b/sim/play_tdmpc.py
@@ -0,0 +1,220 @@
+"""Trains a humanoid to stand up."""
+
+import argparse
+import isaacgym
+import os
+import cv2
+import torch
+from sim.envs import task_registry
+from sim.utils.helpers import get_args
+from sim.algo.tdmpc.src import logger
+from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer
+from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC
+from dataclasses import dataclass, field
+from isaacgym import gymapi
+from typing import List
+import time
+import numpy as np
+from pathlib import Path
+import random
+from datetime import datetime
+from tqdm import tqdm
+
+torch.backends.cudnn.benchmark = True
+__LOGS__ = "logs"
+
+
+@dataclass
+class TDMPC_DoraConfigs:
+ seed: int = 42
+ task: str = "walk"
+ exp_name: str = "dora"
+ device: str = "cuda:0"
+ num_envs: int = 10
+ episode_length: int = 100
+ max_episode_length: int = 1000
+ max_clip_actions: float = 1.0
+ clip_actions: str = f"linear(1, {max_clip_actions}, 100000)"
+
+ lr: float = 1e-3
+ modality: str = "state"
+ enc_dim: int = 512 # 256
+    mlp_dim: List[int] = field(default_factory=lambda: [512, 256])  # [256, 256]
+ latent_dim: int = 100
+
+ iterations: int = 12
+ num_samples: int = 512
+ num_elites: int = 50
+ mixture_coef: float = 0.05
+ min_std: float = 0.05
+ temperature: float = 0.5
+ momentum: float = 0.1
+ horizon: int = 5
+ std_schedule: str = f"linear(0.5, {min_std}, 3000)"
+ horizon_schedule: str = f"linear(1, {horizon}, 2500)"
+
+ batch_size: int = 1024
+ max_buffer_size: int = 1000000
+ reward_coef: float = 1
+ value_coef: float = 0.5
+ consistency_coef: float = 2
+ rho: float = 0.5
+ kappa: float = 0.1
+ per_alpha: float = 0.6
+ per_beta: float = 0.4
+ grad_clip_norm: float = 10
+ seed_steps: int = 750
+ update_freq: int = 2
+    tau: float = 0.01
+
+ discount: float = 0.99
+ buffer_device: str = "cpu"
+ train_steps: int = int(1e6)
+ num_q: int = 3
+
+ action_repeat: int = 2
+ eval_freq: int = 15000
+ eval_freq_episode: int = 10
+ eval_episodes: int = 1
+
+ save_model: bool = True
+ save_video: bool = False
+ eval_model: bool = False
+ save_buffer: bool = True
+ save_buffer_freq_episode: int = 50
+ save_model_freq_episode: int = 10
+
+ use_wandb: bool = False
+ wandb_entity: str = "crajagopalan"
+ wandb_project: str = "xbot"
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+class VideoRecorder:
+ """Utility class for logging evaluation videos."""
+
+ def __init__(self, root_dir, render_size=384, fps=25):
+ self.save_dir = (root_dir / "eval_video") if root_dir else None
+ self.render_size = render_size
+ self.fps = fps
+ self.frames = []
+ self.enabled = False
+ fourcc = cv2.VideoWriter_fourcc(*"MP4V") # type: ignore[attr-defined]
+ logger.make_dir(self.save_dir)
+ now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        # Build the output video file path inside the (already created) eval_video directory.
+        video_path = os.path.join(self.save_dir, now + ".mp4")
+        self.video = cv2.VideoWriter(video_path, fourcc, float(fps), (1920, 1080))
+
+ def init(self, env, h1, enabled=True):
+ self.frames = []
+ self.enabled = self.save_dir and enabled
+ self.record(env, h1)
+
+ def record(self, env, h1):
+ env.gym.fetch_results(env.sim, True)
+ env.gym.step_graphics(env.sim)
+ env.gym.render_all_camera_sensors(env.sim)
+ img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR)
+ img = np.reshape(img, (1080, 1920, 4))
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
+ self.video.write(img)
+
+ def save(
+ self,
+ ):
+ self.video.release()
+
+
+def evaluate(test_env, agent, h1, step, video, action_repeat=1):
+ """Evaluate a trained agent and optionally save a video."""
+ episode_rewards = []
+ obs, privileged_obs = test_env.reset()
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+ dones, ep_reward, t = torch.tensor([False] * test_env.num_envs), torch.tensor([0.0] * test_env.num_envs), 0
+ if video:
+ video.init(test_env, h1, enabled=True)
+ for i in tqdm(range(int(1000 // action_repeat))):
+ actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0)
+ for _ in range(action_repeat):
+ obs, privileged_obs, rewards, dones, infos = test_env.step(actions)
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ ep_reward += rewards.cpu()
+ t += 1
+ if video:
+ video.record(test_env, h1)
+ state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+ episode_rewards.append(ep_reward)
+ if video:
+ video.save()
+ print(f"Timestep : {t} Episode Rewards - {torch.cat(episode_rewards).mean().item()}")
+ return torch.nanmean(torch.cat(episode_rewards)).item()
+
+
+def play(args: argparse.Namespace) -> None:
+ """Training script for TD-MPC. Requires a CUDA-enabled device."""
+ assert torch.cuda.is_available()
+ env_cfg, _ = task_registry.get_cfgs(name=args.task)
+ env, _ = task_registry.make_env(name=args.task, args=args)
+
+ fp = "/home/guest/sim/logs/2024-09-17_00-45-33_walk_state_dora/models/tdmpc_policy_350.pt"
+ config = torch.load(fp)["config"]
+ tdmpc_cfg = TDMPC_DoraConfigs(**config)
+ env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat)
+
+ camera_properties = gymapi.CameraProperties()
+ camera_properties.width = 1920
+ camera_properties.height = 1080
+ h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties)
+ camera_offset = gymapi.Vec3(3, -3, 1)
+ camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135))
+ actor_handle = env.gym.get_actor_handle(env.envs[0], 0)
+ body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0)
+ env.gym.attach_camera_to_body(
+ h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION
+ )
+
+ set_seed(tdmpc_cfg.seed)
+ work_dir = (
+ Path().cwd() / __LOGS__ / f"{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}_{str(tdmpc_cfg.seed)}"
+ )
+
+ obs, privileged_obs = env.reset()
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+
+ tdmpc_cfg.obs_shape = [state.shape[0]]
+ tdmpc_cfg.action_shape = env.num_actions
+ tdmpc_cfg.action_dim = env.num_actions
+ tdmpc_cfg.episode_length = 100 # int(env.max_episode_length // tdmpc_cfg.action_repeat)
+ tdmpc_cfg.num_envs = env.num_envs
+
+ L = logger.Logger(work_dir, tdmpc_cfg)
+ log_dir = logger.make_dir(work_dir)
+ video = VideoRecorder(log_dir)
+
+ agent = TDMPC(tdmpc_cfg)
+
+ agent.load(fp)
+ step = 0
+ episode_idx, start_time = 0, time.time()
+ if fp is not None:
+ episode_idx = int(fp.split(".")[0].split("_")[-1])
+ step = episode_idx * tdmpc_cfg.episode_length
+
+ # Log training episode
+ evaluate(env, agent, h1, step, video, tdmpc_cfg.action_repeat)
+ print("Testing completed successfully")
+
+
+if __name__ == "__main__":
+ play(get_args())
diff --git a/sim/resources/stompymini/__init__.py b/sim/resources/stompymini/__init__.py
new file mode 100755
index 00000000..e69de29b
diff --git a/sim/resources/stompymini/robot.xml b/sim/resources/stompymini/robot.xml
new file mode 100755
index 00000000..db5f3f38
--- /dev/null
+++ b/sim/resources/stompymini/robot.xml
@@ -0,0 +1,598 @@
diff --git a/sim/resources/stompymini/robot_fixed.xml b/sim/resources/stompymini/robot_fixed.xml
new file mode 100755
index 00000000..aa4981d9
--- /dev/null
+++ b/sim/resources/stompymini/robot_fixed.xml
@@ -0,0 +1,961 @@
diff --git a/sim/train_tdmpc.py b/sim/train_tdmpc.py
new file mode 100644
index 00000000..d8391882
--- /dev/null
+++ b/sim/train_tdmpc.py
@@ -0,0 +1,253 @@
+"""Trains a humanoid to stand up."""
+
+import argparse
+import isaacgym
+import torch
+from sim.envs import task_registry
+from sim.utils.helpers import get_args
+from sim.algo.tdmpc.src import logger
+from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer
+from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC
+from dataclasses import dataclass, field, asdict
+from datetime import datetime
+from isaacgym import gymapi
+from typing import List
+import time
+import numpy as np
+from pathlib import Path
+import random
+
+torch.backends.cudnn.benchmark = True
+__LOGS__ = "logs"
+
+
+@dataclass
+class TDMPC_DoraConfigs:
+ seed: int = 42
+ task: str = "walk"
+ exp_name: str = "dora"
+ device: str = "cuda:0"
+ num_envs: int = 10
+ max_clip_actions: float = 1.0
+ clip_actions: str = f"linear(1, {max_clip_actions}, 100000)"
+ episode_length: int = 100
+ max_episode_length: int = 1000
+
+ lr: float = 1e-3
+ modality: str = "state"
+ enc_dim: int = 512
+    mlp_dim: List[int] = field(default_factory=lambda: [512, 256])
+ latent_dim: int = 100
+
+ iterations: int = 12
+ num_samples: int = 512
+ num_elites: int = 50
+ mixture_coef: float = 0.05
+ min_std: float = 0.1
+ temperature: float = 0.5
+ momentum: float = 0.1
+ horizon: int = 5
+ std_schedule: str = f"linear(0.5, {min_std}, 200000)"
+ horizon_schedule: str = f"linear(1, {horizon}, 15000)"
+
+ batch_size: int = 8192
+ max_buffer_size: int = int(5e6)
+ reward_coef: float = 1
+ value_coef: float = 0.75
+ consistency_coef: float = 2
+ rho: float = 0.75
+ kappa: float = 0.1
+ per_alpha: float = 0.6
+ per_beta: float = 0.4
+ grad_clip_norm: float = 50
+ seed_steps: int = 500
+ update_freq: int = 3
+    tau: float = 0.05
+
+ discount: float = 0.99
+ buffer_device: str = "cpu"
+ train_steps: int = int(1e6)
+ num_q: int = 2
+
+ action_repeat: int = 2
+
+ save_model: bool = True
+ save_video: bool = False
+ save_buffer: bool = True
+ eval_model: bool = False
+ eval_freq_episode: int = 10
+ eval_episodes: int = 1
+ save_buffer_freq_episode: int = 50
+ save_model_freq_episode: int = 50
+
+ use_wandb: bool = False
+ wandb_entity: str = "crajagopalan"
+ wandb_project: str = "xbot"
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def evaluate(test_env, agent, h1, num_episodes, step, env_step, video, action_repeat=1):
+ """Evaluate a trained agent and optionally save a video."""
+ episode_rewards = []
+ for i in range(num_episodes):
+ obs, privileged_obs = test_env.reset()
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+ dones, ep_reward, t = torch.tensor([False]), 0, 0
+ if video:
+ video.init(test_env, h1, enabled=(i == 0))
+ while not dones[0].item():
+ actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0)
+ for _ in range(action_repeat):
+ obs, privileged_obs, rewards, dones, infos = test_env.step(actions)
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+ ep_reward += rewards[0]
+ if video:
+ video.record(test_env, h1)
+ t += 1
+ episode_rewards.append(ep_reward)
+ if video:
+ video.save(env_step)
+ return torch.nanmean(torch.tensor(episode_rewards)).item()
+
+
+def train(args: argparse.Namespace) -> None:
+ """Training script for TD-MPC. Requires a CUDA-enabled device."""
+ assert torch.cuda.is_available()
+ env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
+ env, _ = task_registry.make_env(name=args.task, args=args)
+
+ tdmpc_cfg = TDMPC_DoraConfigs()
+
+ if tdmpc_cfg.save_video:
+ env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat)
+
+ camera_properties = gymapi.CameraProperties()
+ camera_properties.width = 160
+ camera_properties.height = 120
+ h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties)
+ camera_offset = gymapi.Vec3(3, -3, 1)
+ camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135))
+ actor_handle = env.gym.get_actor_handle(env.envs[0], 0)
+ body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0)
+ env.gym.attach_camera_to_body(
+ h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION
+ )
+
+ set_seed(tdmpc_cfg.seed)
+ now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+ work_dir = Path().cwd() / __LOGS__ / f"{now}_{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}"
+
+ obs, privileged_obs = env.reset()
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+
+ tdmpc_cfg.obs_shape = state.shape[1:]
+ tdmpc_cfg.action_shape = env.num_actions
+ tdmpc_cfg.action_dim = env.num_actions
+ episode_length = 100
+ tdmpc_cfg.episode_length = episode_length
+ tdmpc_cfg.num_envs = env.num_envs
+ tdmpc_cfg.max_episode_length = int(env.max_episode_length)
+ tdmpc_cfg.max_clip_actions = env.cfg.normalization.clip_actions
+ tdmpc_cfg.clip_actions = f"linear(1, {env.cfg.normalization.clip_actions}, 50000)"
+ print(tdmpc_cfg.clip_actions)
+
+ agent = TDMPC(tdmpc_cfg)
+ buffer = ReplayBuffer(tdmpc_cfg)
+ fp = None
+
+ init_step, env_step = 0, 0
+ episode_idx, start_time = 0, time.time()
+
+ if fp is not None:
+ init_step, env_step = agent.load(fp)
+ if init_step is None:
+ episode_idx = int(fp.split(".")[0].split("_")[-1])
+ init_step = episode_idx * tdmpc_cfg.episode_length
+ env_step = init_step * tdmpc_cfg.episode_length
+ episode = Episode(tdmpc_cfg, state)
+
+ L = logger.Logger(work_dir, tdmpc_cfg)
+
+ for step in range(init_step, tdmpc_cfg.train_steps + tdmpc_cfg.episode_length, tdmpc_cfg.episode_length):
+ if episode.full:
+ episode = Episode(tdmpc_cfg, state)
+ for i in range(tdmpc_cfg.episode_length):
+ actions = agent.plan(state, t0=episode.first, eval_mode=False, step=step)
+ original_state = state.clone()
+ total_rewards, total_dones, total_timeouts = [], [], []
+ for _ in range(tdmpc_cfg.action_repeat):
+ obs, privileged_obs, rewards, dones, infos = env.step(actions)
+ critic_obs = privileged_obs if privileged_obs is not None else obs
+ state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+ total_rewards.append(rewards)
+ total_dones.append(dones)
+ total_timeouts.append(infos["time_outs"])
+ episode += (
+ original_state,
+ actions,
+ torch.stack(total_rewards).sum(dim=0),
+ torch.stack(total_dones).any(dim=0),
+ torch.stack(total_timeouts).any(dim=0),
+ torch.stack(total_dones).any(dim=0),
+ )
+ buffer += episode
+
+ # Update model
+ train_metrics = {}
+ if step >= tdmpc_cfg.seed_steps:
+ num_updates = tdmpc_cfg.seed_steps if step == tdmpc_cfg.seed_steps else tdmpc_cfg.episode_length
+ for i in range(num_updates):
+ train_metrics.update(agent.update(buffer, step + i))
+
+ # Log training episode
+ episode_idx += 1
+ env_step += int(tdmpc_cfg.episode_length * tdmpc_cfg.num_envs)
+ common_metrics = {
+ "episode": episode_idx,
+ "step": step + tdmpc_cfg.episode_length,
+ "env_step": env_step,
+ "total_time": time.time() - start_time,
+ "episode_reward": episode.cumulative_reward.sum().item() / env.num_envs,
+ "mean_episode_length": episode.episode_length,
+ }
+ train_metrics.update(common_metrics)
+ L.log(train_metrics, category="train")
+
+ # Evaluate agent periodically
+ if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.save_model_freq_episode == 0:
+ L.save(
+ agent,
+ f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt",
+ step + tdmpc_cfg.episode_length,
+ env_step,
+ )
+ if tdmpc_cfg.save_buffer and episode_idx % tdmpc_cfg.save_buffer_freq_episode == 0:
+ buffer.save(str(work_dir / "buffer.pt"))
+ if tdmpc_cfg.eval_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0:
+ common_metrics["episode_reward"] = evaluate(
+ env,
+ agent,
+ h1 if L.video is not None else None,
+ tdmpc_cfg.eval_episodes,
+ step,
+ env_step,
+ L.video,
+ tdmpc_cfg.action_repeat,
+ )
+ L.log(common_metrics, category="eval")
+
+ print("Training completed successfully")
+
+
+if __name__ == "__main__":
+ # python -m sim.humanoid_gym.train
+ train(get_args())