diff --git a/examples/locomotion/go2_env.py b/examples/locomotion/go2_env.py index a23f7ed4..4368b8ab 100644 --- a/examples/locomotion/go2_env.py +++ b/examples/locomotion/go2_env.py @@ -150,6 +150,11 @@ def step(self, actions): self.reset_buf = self.episode_length_buf > self.max_episode_length self.reset_buf |= torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"] self.reset_buf |= torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"] + + time_out_idx = (self.episode_length_buf > self.max_episode_length).nonzero(as_tuple=False).flatten() + self.extras["time_outs"] = torch.zeros_like(self.reset_buf, device=self.device, dtype=gs.tc_float) + self.extras["time_outs"][time_out_idx] = 1.0 + self.reset_idx(self.reset_buf.nonzero(as_tuple=False).flatten()) # compute reward