diff --git a/environment.yml b/environment.yml
index 2623436e9..d17556a1d 100644
--- a/environment.yml
+++ b/environment.yml
@@ -48,6 +48,7 @@ dependencies:
  - git+https://github.com/Theano/Theano.git@adfe319ce6b781083d8dc3200fb4481b00853791#egg=Theano
  - git+https://github.com/neocxi/Lasagne.git@484866cf8b38d878e92d521be445968531646bb8#egg=Lasagne
  - git+https://github.com/plotly/plotly.py.git@2594076e29584ede2d09f2aa40a8a195b3f3fc66#egg=plotly
+ - git+https://github.com/deepmind/dm_control.git#egg=dm_control
  - awscli
  - git+https://github.com/openai/gym.git@v0.7.4#egg=gym
  - pyglet
diff --git a/rllab/envs/dm_control_env.py b/rllab/envs/dm_control_env.py
new file mode 100644
index 000000000..0a82bb2df
--- /dev/null
+++ b/rllab/envs/dm_control_env.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+from dm_control import suite
+from dm_control.rl.environment import StepType
+from dm_control.rl.control import flatten_observation
+
+from rllab.envs.base import Env, Step
+from rllab.envs.dm_control_viewer import DmControlViewer
+from rllab.core.serializable import Serializable
+from rllab.spaces.box import Box
+from rllab.spaces.discrete import Discrete
+
+
+class DmControlEnv(Env, Serializable):
+    '''
+    Wrapper for the dm_control suite (https://arxiv.org/pdf/1801.00690.pdf),
+    exposing its domains and tasks as rllab environments.
+    '''
+
+    def __init__(
+            self,
+            domain_name,
+            task_name,
+            plot=False,
+            width=320,
+            height=240,
+    ):
+        Serializable.quick_init(self, locals())
+
+        self._env = suite.load(domain_name=domain_name, task_name=task_name)
+
+        self._total_reward = 0
+        self._render_kwargs = {'width': width, 'height': height}
+
+        if plot:
+            self._viewer = DmControlViewer()
+        else:
+            self._viewer = None
+
+    def step(self, action):
+        time_step = self._env.step(action)
+        if time_step.reward:
+            self._total_reward += time_step.reward
+        return Step(flatten_observation(time_step.observation),
+                    time_step.reward,
+                    time_step.step_type == StepType.LAST,
+                    **time_step.observation)
+
+    def reset(self):
+        self._total_reward = 0
+        time_step = self._env.reset()
+        return flatten_observation(time_step.observation)
+
+    def render(self):
+        if self._viewer:
+            pixels_img = self._env.physics.render(**self._render_kwargs)
+            self._viewer.loop_once(pixels_img)
+
+    def terminate(self):
+        if self._viewer:
+            self._viewer.finish()
+
+    def _flat_shape(self, observation):
+        return sum(int(np.prod(v.shape)) for v in observation.values())
+
+    @property
+    def action_space(self):
+        action_spec = self._env.action_spec()
+        if (len(action_spec.shape) == 1) and (-np.inf in action_spec.minimum or
+                                              np.inf in action_spec.maximum):
+            return Discrete(np.prod(action_spec.shape))
+        else:
+            return Box(action_spec.minimum, action_spec.maximum)
+
+    @property
+    def observation_space(self):
+        flat_dim = self._flat_shape(self._env.observation_spec())
+        return Box(low=-np.inf, high=np.inf, shape=[flat_dim])
+
+    @property
+    def total_reward(self):
+        return self._total_reward
diff --git a/rllab/envs/dm_control_viewer.py b/rllab/envs/dm_control_viewer.py
new file mode 100644
index 000000000..a71d62351
--- /dev/null
+++ b/rllab/envs/dm_control_viewer.py
@@ -0,0 +1,24 @@
+import pygame
+import numpy as np
+
+CAPTION = "dm_control viewer"
+
+
+class DmControlViewer:
+    def __init__(self):
+        pygame.init()
+        pygame.display.set_caption(CAPTION)
+        self.screen = None
+
+    def loop_once(self, image):
+        image = np.swapaxes(image, 0, 1)  # pygame surfarrays are indexed (width, height)
+
+        if not self.screen:
+            self.screen = pygame.display.set_mode((image.shape[0],
+                                                   image.shape[1]))
+
+        pygame.surfarray.blit_array(self.screen, image)
+        pygame.display.flip()
+
+    def finish(self):
+        pygame.quit()
diff --git a/tests/test_dmcontrol.py b/tests/test_dmcontrol.py
new file mode 100644
index 000000000..9a9ef81d8
--- /dev/null
+++ b/tests/test_dmcontrol.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from rllab.envs.dm_control_env import DmControlEnv
+from rllab.envs.normalized_env import normalize
+
+from dm_control import suite
+
+
+def run_task(domain_name, task_name):
+    print("run: domain %s task %s" % (domain_name, task_name))
+    dmcontrol_env = normalize(
+        DmControlEnv(
+            domain_name=domain_name,
+            task_name=task_name,
+            plot=True,
+            width=600,
+            height=400),
+        normalize_obs=False,
+        normalize_reward=False)
+
+    dmcontrol_env.reset()
+    action_space = dmcontrol_env.action_space
+    for _ in range(200):
+        dmcontrol_env.render()
+        action = action_space.sample()
+        next_obs, reward, done, info = dmcontrol_env.step(action)
+        if done:
+            break
+
+    dmcontrol_env.terminate()
+
+
+for domain, task in suite.ALL_TASKS:
+    run_task(domain, task)
+
+print("Congratulations! All tasks are done!")