-
Notifications
You must be signed in to change notification settings - Fork 4
/
envs.py
124 lines (105 loc) · 4.46 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
from typing import Optional, Tuple, List
from gym.spaces import Box
from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_
from gym.wrappers import TimeLimit
from copy import deepcopy
class HalfCheetahEnv(HalfCheetahEnv_):
    """HalfCheetah with a flattened float32 observation that also includes
    the torso center-of-mass, plus a fixed tracking camera for rendering."""

    def _get_obs(self):
        """Return the observation as a flat float32 array.

        Concatenates joint positions (dropping the root x coordinate),
        joint velocities, and the torso center-of-mass position.
        """
        return np.concatenate([
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
            self.get_body_com("torso").flat,
        ]).astype(np.float32).flatten()

    def viewer_setup(self):
        """Attach the viewer to the model's 'track' camera and hide the overlay."""
        camera_id = self.model.camera_name2id('track')
        self.viewer.cam.type = 2  # mjCAMERA_FIXED: use the named fixed camera
        self.viewer.cam.fixedcamid = camera_id
        self.viewer.cam.distance = self.model.stat.extent * 0.35
        # Hide the overlay
        self.viewer._hide_overlay = True

    def render(self, mode='human'):
        """Render the scene.

        Args:
            mode: 'human' opens/updates a window; 'rgb_array' returns the
                raw pixel buffer read back from the viewer.

        Returns:
            The pixel array for mode='rgb_array', otherwise None.
        """
        if mode == 'rgb_array':
            self._get_viewer(mode).render()
            # window size used for old mujoco-py:
            width, height = 500, 500
            # BUG FIX: _get_viewer requires the render mode (the sibling
            # calls in this method pass it); the original call passed no
            # argument and raised TypeError on the rgb_array path.
            data = self._get_viewer(mode).read_pixels(width, height, depth=False)
            return data
        elif mode == 'human':
            self._get_viewer(mode).render()
class HalfCheetahDirEnv_(HalfCheetahEnv):
    """Half-cheetah environment with target direction, as described in [1]. The
    code is adapted from
    https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py
    The half-cheetah follows the dynamics from MuJoCo [2], and receives at each
    time step a reward composed of a control cost and a reward equal to its
    velocity in the target direction. The tasks are generated by sampling the
    target directions from a Bernoulli distribution on {-1, 1} with parameter
    0.5 (-1: backward, +1: forward).
    [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic
    Meta-Learning for Fast Adaptation of Deep Networks", 2017
    (https://arxiv.org/abs/1703.03400)
    [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for
    model-based control", 2012
    (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf)
    """
    def __init__(self, task=None, n_tasks=2, randomize_tasks=False):
        """
        Args:
            task: optional {'direction': +1|-1} dict selecting the initial
                goal direction (defaults to forward, +1).
            n_tasks: unused here; kept for interface compatibility.
            randomize_tasks: unused here; kept for interface compatibility.
        """
        # BUG FIX: the default was a mutable dict literal ({}), which is
        # shared across all instances; use the None sentinel instead.
        if task is None:
            task = {}
        directions = [-1, 1]
        self.tasks = [{'direction': direction} for direction in directions]
        self._task = task
        self._goal_dir = task.get('direction', 1)
        self._goal = self._goal_dir
        super(HalfCheetahDirEnv_, self).__init__()

    def step(self, action):
        """Advance the simulation one control step.

        Reward is forward velocity along the goal direction minus a
        quadratic control cost; episodes never terminate from here.
        """
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        forward_vel = (xposafter - xposbefore) / self.dt
        forward_reward = self._goal_dir * forward_vel
        ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action))
        observation = self._get_obs()
        reward = forward_reward - ctrl_cost
        done = False
        infos = dict(reward_forward=forward_reward,
                     reward_ctrl=-ctrl_cost, task=self._task)
        return (observation, reward, done, infos)

    def sample_tasks(self, num_tasks):
        """Sample `num_tasks` directions i.i.d. uniform on {-1, +1}."""
        directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1
        tasks = [{'direction': direction} for direction in directions]
        return tasks

    def get_all_task_idx(self):
        """Return the indices of all predefined tasks."""
        return range(len(self.tasks))

    def reset_task(self, idx):
        """Switch to the task at index `idx` and reset the environment."""
        self._task = self.tasks[idx]
        self._goal_dir = self._task['direction']
        self._goal = self._goal_dir
        self.reset()
class HalfCheetahDirEnv(HalfCheetahDirEnv_):
    """HalfCheetahDirEnv_ with an explicit task list and an optional one-hot
    task indicator appended to the observation."""

    def __init__(self, tasks: List[dict], include_goal: bool = False):
        """
        Args:
            tasks: list of {'direction': +1|-1} dicts; if None, falls back
                to the two canonical directions (forward, backward).
            include_goal: if True, append a one-hot encoding of the current
                task index to every observation.
        """
        self.include_goal = include_goal
        super(HalfCheetahDirEnv, self).__init__()
        if tasks is None:
            tasks = [{'direction': 1}, {'direction': -1}]
        self.tasks = tasks
        self.set_task_idx(0)
        self._max_episode_steps = 200

    def _get_obs(self):
        """Return the base observation, optionally concatenated with a
        one-hot encoding of the current task's index in self.tasks."""
        if self.include_goal:
            idx = 0
            try:
                idx = self.tasks.index(self._task)
            except ValueError:
                # BUG FIX: was a bare `except:` which also swallowed
                # KeyboardInterrupt/SystemExit; list.index only raises
                # ValueError when the task is missing. Fall back to index 0.
                pass
            one_hot = np.zeros(len(self.tasks), dtype=np.float32)
            one_hot[idx] = 1.0
            obs = super()._get_obs()
            obs = np.concatenate([obs, one_hot])
        else:
            obs = super()._get_obs()
        return obs

    def set_task(self, task):
        """Make `task` the current task and reset the environment."""
        self._task = task
        self._goal_dir = self._task['direction']
        self.reset()

    def set_task_idx(self, idx):
        """Switch to the task at index `idx` in self.tasks."""
        self.set_task(self.tasks[idx])