diff --git a/luchador/agent/dqn.py b/luchador/agent/dqn.py index cfbf858c..0644b1dd 100644 --- a/luchador/agent/dqn.py +++ b/luchador/agent/dqn.py @@ -167,8 +167,6 @@ def __init__( self._saver = None self._ql = None self._eg = None - self._stack_buffer = None - self._previous_stack = None self._summary_writer = None self._summary_values = { 'errors': [], @@ -202,9 +200,7 @@ def _init_network(self, n_actions): ########################################################################### # Methods for `reset` - def reset(self, initial_observation): - self._stack_buffer = [initial_observation[0]] - self._previous_stack = None + def reset(self, _): self._ready = False ########################################################################### @@ -227,26 +223,17 @@ def _predict_q(self): # Methods for `learn` def learn(self, state0, action, reward, state1, terminal, info=None): self._n_obs += 1 - self._record(action, reward, state1, terminal) + self._record(state0, action, reward, state1, terminal) self._train() - def _record(self, action, reward, state1, terminal): + def _record(self, state0, action, reward, state1, terminal): """Stack states and push them to recorder, then sort memory""" - self._stack_buffer.append(state1[0]) + self._recorder.push(1, { + 'state0': state0, 'action': action, 'reward': reward, + 'state1': state1, 'terminal': terminal}) + self._ready = True cfg = self.args['record_config'] - if len(self._stack_buffer) == cfg['stack'] + 1: - if self._previous_stack is None: - self._previous_stack = np.array(self._stack_buffer[:-1]) - state0_ = self._previous_stack - state1_ = np.array(self._stack_buffer[1:]) - self._recorder.push(1, { - 'state0': state0_, 'action': action, 'reward': reward, - 'state1': state1_, 'terminal': terminal}) - self._stack_buffer = self._stack_buffer[1:] - self._previous_stack = state1_ - self._ready = True - sort_freq = cfg['sort_frequency'] if sort_freq > 0 and self._n_obs % sort_freq == 0: _LG.info('Sorting Memory') diff --git 
a/luchador/env/ale/ale.py b/luchador/env/ale/ale.py index 1cbff5f0..77fbf9b0 100644 --- a/luchador/env/ale/ale.py +++ b/luchador/env/ale/ale.py @@ -18,55 +18,61 @@ _ROM_DIR = os.path.join(_DIR, 'rom') -class Preprocessor(object): - """Store the latest frames and take max/mean over them +class StateStack(object): + """Stack multiple states Parameters ---------- - frame_shape : list of two int - Order is (height, width) - - channel : int - 1 or 3 - - buffer_size : int - The number of frames to process. Default: 2. - - mode : int - `max` or `mean` + stack_size : int + The number of states to stack. Default: 4. """ - def __init__(self, frame_shape, channel, buffer_size=2, mode='max'): - self.frame_shape = list(frame_shape) - self.buffer_size = buffer_size - self.channel = channel - self.mode = mode - - buffer_shape = [buffer_size, channel] + self.frame_shape - self._buffer = np.zeros(buffer_shape, dtype=np.uint8) - self._func = np.max if mode == 'max' else np.mean - self._index = 0 + def __init__(self, stack_size=4): + self.stack_size = stack_size + self._buffer = None - def reset(self, initial_frame): - """Reset buffer with new frame + def reset(self, initial_state): + """Reset stack buffer by filling it with initial state Parameters ---------- - initial_frame : NumPy Array - The initial observation obtained from resetting env + initial_state : state object """ - for _ in range(self.buffer_size): - self.append(initial_frame) + self._buffer = [initial_state] * self.stack_size - def append(self, frame): - """Update buffer with new frame + def append(self, state): + """Append new state and discard old state Parameters ---------- - frame : NumPy Array - The observation obtained by taking a step in env + state : state object """ - self._buffer[self._index] = frame - self._index = (self._index + 1) % self.buffer_size + self._buffer.append(state) + self._buffer = self._buffer[-self.stack_size:] + + def get(self): + """Get the current stack + + Returns + ------- + 
list of states + """ + return self._buffer + + +class Preprocessor(StateStack): + """Store the latest frames and take max/mean over them + + Parameters + ---------- + buffer_size : int + The number of frames to process. Default: 2. + + mode : str + `max` or `mean` + """ + def __init__(self, buffer_size=2, mode='max'): + super(Preprocessor, self).__init__(stack_size=buffer_size) + self._func = np.max if mode == 'max' else np.mean def get(self): """Return preprocessed frame @@ -120,45 +126,70 @@ def _make_ale( class ALEEnvironment(StoreMixin, BaseEnvironment): - """Atari Environment""" - @staticmethod - def get_roms(): - """Get the list of ROMs available + """Atari Environment - Returns: - list of srting: Names of available ROMs - """ - return [rom for rom in os.listdir(_ROM_DIR) - if rom.endswith('.bin')] + Parameters + ---------- + rom : str + ROM name. Use `get_roms` for the list of available ROMs. - def _validate_args( - self, mode, preprocess_mode, - repeat_action, random_start, rom, **_): - if mode not in ['test', 'train']: - raise ValueError('`mode` must be either `test` or `train`') + mode : str + When `train`, a loss of life is considered as terminal condition. + When `test`, a loss of life is not considered as terminal condition. - if preprocess_mode not in ['max', 'average']: - raise ValueError( - '`preprocess_mode` must be either `max` or `average`') + width, height : int + Output screen size. - if repeat_action < 1: - raise ValueError( - '`repeat_action` must be integer greater than 0') + stack : int + Stack the environment state. The first dimension of the state + returned by ``step`` is the stack (3D if grayscale, else 4D). - if random_start and random_start < 1: - raise ValueError( - '`random_start` must be `None` or integer greater than 0' - ) + grayscale : bool + If True, output screen is gray scale and has no color channel. i.e. + output shape == (h, w). 
Otherwise output screen has color channel with + shape (h, w, 3) - rom_path = os.path.join(_ROM_DIR, rom) - if not os.path.isfile(rom_path): - raise ValueError('ROM ({}) not found.'.format(rom)) + repeat_action : int + When calling `step` method, action is repeated for this number of times + internally, unless a terminal condition is met. + + minimal_action_set : bool + When True, `n_actions` property reports actions only meaningful to the + loaded ROM. Otherwise all the 18 actions are counted. + + random_seed : int + ALE's random seed + + random_start : int or None + When given, at the beginning of each episode at most this number of + frames are played with action == 0. This technique is used to acquire + more diverse states of environment. + + buffer_frames : int + The number of latest frames to preprocess. + + preprocess_mode : str + Either `max` or `average`. When obtaining observation, pixel-wise + maximum or average over buffered frames are taken before resizing + + display_screen : bool + Display screen when True. + + play_sound : bool + Play sound + record_screen_path : str + Passed to ALE. Save the raw screens into the path. + + record_sound_filename : str + Passed to ALE. Save sound to a file. + """ def __init__( self, rom, mode='train', width=160, height=210, + stack=4, grayscale=True, repeat_action=4, buffer_frames=2, @@ -171,70 +202,11 @@ def __init__( record_screen_path=None, record_sound_filename=None, ): - """Initialize ALE Environment with the given parmeters - - Parameters - ---------- - rom : str - ROM name. Use `get_roms` for the list of available ROMs. - - mode : str - When `train`, a loss of life is considered as terminal condition. - When `test`, a loss of life is not considered as terminal - condition. - - width, height : int - Output screen size. - - grayscale : bool - If True, output screen is gray scale and has no color channel. - i.e. output shape == (h, w). 
Otherwise output screen has color - channel with shape (h, w, 3) - - repeat_action : int - When calling `step` method, action is repeated for this numebr of - times internally, unless a terminal condition is met. - - minimal_action_set : bool - When True, `n_actions` property reports actions only meaningfull to - the loaded ROM. Otherwise all the 18 actions are dounted. - - random_seed : int - ALE's random seed - - random_start : int or None - When given, at the beginning of each episode at most this number - of frames are played with action == 0. This technique is used to - acquire more diverse states of environment. - - buffer_frames : int - The number of latest frame to preprocess. - - preprocess_mode : str - Either `max` or `average`. When obtaining observation, pixel-wise - maximum or average over buffered frames are taken before resizing - - display_screen : bool - Display sceen when True. - - play_sound : bool - Play sound - - record_screen_path : str - Passed to ALE. Save the original screens into the path. - - Note - that this is different from the observation returned by - `step` method. - - record_screen_filename : str - Passed to ALE. Save sound to a file. 
- """ if not rom.endswith('.bin'): rom += '.bin' self._store_args( - rom=rom, mode=mode, width=width, height=height, + rom=rom, mode=mode, width=width, height=height, stack=stack, grayscale=grayscale, repeat_action=repeat_action, buffer_frames=buffer_frames, preprocess_mode=preprocess_mode, minimal_action_set=minimal_action_set, random_seed=random_seed, @@ -252,56 +224,80 @@ def __init__( self._ale = _make_ale(**self.args) self._actions = ( self._ale.getMinimalActionSet() - if self.args['minimal_action_set'] else + if minimal_action_set else self._ale.getLegalActionSet() ) - self._get_raw_screen = ( self._ale.getScreenGrayscale - if self.args['grayscale'] else + if grayscale else self._ale.getScreenRGB ) - self._init_raw_buffer() - self._preprocessor = Preprocessor( - frame_shape=(self.args['height'], self.args['width']), - channel=1 if self.args['grayscale'] else 3, - buffer_size=self.args['buffer_frames'], - mode=self.args['preprocess_mode']) + self._raw_buffer_shape = ( + (210, 160) if self.args['grayscale'] else (210, 160, 3)) self._init_resize() + self._processor = Preprocessor( + buffer_size=buffer_frames, mode=preprocess_mode) + self._stack = StateStack(stack_size=stack) + + def _validate_args( + self, mode, preprocess_mode, + repeat_action, random_start, rom, **_): + if mode not in ['test', 'train']: + raise ValueError('`mode` must be either `test` or `train`') - def _init_raw_buffer(self): - w, h = self._ale.getScreenDims() - shape = (h, w) if self.args['grayscale'] else (h, w, 3) - self._raw_buffer = np.zeros(shape, dtype=np.uint8) + if preprocess_mode not in ['max', 'average']: + raise ValueError( + '`preprocess_mode` must be either `max` or `average`') + + if repeat_action < 1: + raise ValueError( + '`repeat_action` must be integer greater than 0') + + if random_start and random_start < 1: + raise ValueError( + '`random_start` must be `None` or integer greater than 0' + ) + + rom_path = os.path.join(_ROM_DIR, rom) + if not os.path.isfile(rom_path): + 
raise ValueError('ROM ({}) not found.'.format(rom)) def _init_resize(self): + """Initialize resize method""" orig_width, orig_height = self._ale.getScreenDims() h, w = self.args['height'], self.args['width'] if not (h == orig_height and w == orig_width): self.resize = (h, w) if self.args['grayscale'] else (h, w, 3) + ########################################################################### + @staticmethod + def get_roms(): + """Get the list of ROMs available + + Returns: + list of string: Names of available ROMs + """ + return [rom for rom in os.listdir(_ROM_DIR) + if rom.endswith('.bin')] + ########################################################################### @property def n_actions(self): return len(self._actions) ########################################################################### + # Helper methods common to `reset` and `step` def _get_resized_frame(self): - """Fetch the current frame and resize then convert to CHW format""" - self._get_raw_screen(screen_data=self._raw_buffer) + """Fetch the current frame and resize""" + buffer_ = np.zeros(shape=self._raw_buffer_shape, dtype=np.uint8) + self._get_raw_screen(screen_data=buffer_) if self.resize: - screen = imresize(self._raw_buffer, self.resize) - else: - screen = self._raw_buffer - if self.args['grayscale']: - return screen[None, ...] 
- return screen.transpose((2, 0, 1)) + return imresize(buffer_, self.resize) + return buffer_ - def _random_play(self): - rand = self.args['random_start'] - repeat = 1 + (np.random.randint(rand) if rand else 0) - return sum(self._step(0) for _ in range(repeat)) + def _get_state(self): + return np.array(self._stack.get()) def _get_info(self): return { @@ -310,23 +306,38 @@ def _get_info(self): 'episode_frame_number': self._ale.getEpisodeFrameNumber(), } + def _is_terminal(self): + if self.args['mode'] == 'train': + return self._ale.game_over() or self.life_lost + return self._ale.game_over() + + ########################################################################### + def _random_play(self): + rand = self.args['random_start'] + repeat = 1 + (np.random.randint(rand) if rand else 0) + return sum(self._step(0) for _ in range(repeat)) + + def _reset(self): + """Actually reset game""" + self._ale.reset_game() + self._processor.reset(self._get_resized_frame()) + self._stack.reset(self._processor.get()) + rewards = self._random_play() + self._stack.append(self._processor.get()) + return rewards + def reset(self): """Reset game - In test mode, the game is simply initialized. In train mode, if the - game is in terminal state due to a life loss but not yet game over, - then only life loss flag is reset so that the next game starts from - the current state. Otherwise, the game is simply initialized. + In ``train`` mode, a loss of life is considered to be terminal state. + If this method is called at such state, then only life_lost flag is + reset so that the next episode can start from the next frame. 
""" - reward = 0 - if ( - self.args['mode'] == 'test' or - not self.life_lost or # `reset` called in a middle of episode - self._ale.game_over() # all lives are lost - ): - self._ale.reset_game() - self._preprocessor.reset(self._get_resized_frame()) - reward += self._random_play() + mode = self.args['mode'] + if mode == 'train' and self.life_lost and not self._ale.game_over(): + reward = 0 + else: + reward = self._reset() self.life_lost = False return Outcome( @@ -354,6 +365,7 @@ def step(self, action): if terminal: break + self._stack.append(self._processor.get()) return Outcome( reward=reward, state=self._get_state(), @@ -363,13 +375,5 @@ def step(self, action): def _step(self, action): reward = self._ale.act(action) - self._preprocessor.append(self._get_resized_frame()) + self._processor.append(self._get_resized_frame()) return reward - - def _get_state(self): - return self._preprocessor.get() - - def _is_terminal(self): - if self.args['mode'] == 'train': - return self._ale.game_over() or self.life_lost - return self._ale.game_over() diff --git a/tests/unit/env/ale/ale_test.py b/tests/unit/env/ale/ale_test.py index ff4a938c..ef2f2cdb 100644 --- a/tests/unit/env/ale/ale_test.py +++ b/tests/unit/env/ale/ale_test.py @@ -10,54 +10,47 @@ class ALEEnvShapeTest(unittest.TestCase): longMessage = True - def _test(self, width=160, height=210, grayscale=True): + def _test(self, width=160, height=210, stack=4, grayscale=True): ale = ALE( rom='breakout', - width=width, height=height, + stack=stack, width=width, height=height, grayscale=grayscale, ) ale.reset() outcome = ale.step(1) - channel = 1 if grayscale else 3 - self.assertEqual(outcome.state.shape, (channel, height, width)) + expected = ( + (stack, height, width) if grayscale else (stack, height, width, 3)) + self.assertEqual(outcome.state.shape, expected) def test_no_resize(self): """State shape equals to the original screen size""" - self._test(grayscale=True) + for gs in [True, False]: + self._test(grayscale=gs) + 
self._test(grayscale=gs, stack=1) def test_resize_width(self): """State shape equals to the given size""" - self._test(width=84, grayscale=True) + for gs in [True, False]: + self._test(width=84, grayscale=gs) + self._test(width=84, grayscale=gs, stack=1) def test_resize_height(self): """State shape equals to the given size""" - self._test(height=84, grayscale=True) + for gs in [True, False]: + self._test(height=84, grayscale=gs) + self._test(height=84, grayscale=gs, stack=1) def test_resize_width_height(self): """State shape equals to the given size""" - self._test(height=84, width=84, grayscale=True) + for gs in [True, False]: + self._test(height=84, width=84, grayscale=gs) + self._test(height=84, width=84, grayscale=gs, stack=1) - def test_no_resize_color(self): - """State shape equals to the original screen size""" - self._test(grayscale=False) - - def test_resize_width_color(self): - """State shape equals to the given size""" - self._test(width=84, grayscale=False) - - def test_resize_height_color(self): - """State shape equals to the given size""" - self._test(height=84, grayscale=False) - def test_resize_width_height_color(self): - """State shape equals to the given size""" - self._test(height=84, width=84, grayscale=False) - - -def _test_buffer(grayscale): +def _test_processor_buffer(grayscale): # pylint: disable=protected-access - buffer_frames = 4 + buffer_frames = 2 ale = ALE( rom='breakout', mode='train', @@ -65,33 +58,52 @@ def _test_buffer(grayscale): buffer_frames=buffer_frames, grayscale=grayscale, ) - buffer_ = ale._preprocessor._buffer - ale.reset() - frame = ale._get_raw_screen().transpose((2, 0, 1)) - np.testing.assert_equal(frame, buffer_[0]) + frame = ale._get_raw_screen().squeeze() + for i in range(buffer_frames): + np.testing.assert_equal(frame, ale._processor._buffer[i]) - for i in range(1, buffer_frames): + previous_frame = frame + for _ in range(buffer_frames): ale.step(1) - frame = ale._get_raw_screen().transpose((2, 0, 1)) - 
np.testing.assert_equal(frame, buffer_[i]) - - for _ in range(10): - for i in range(buffer_frames): - ale.step(1) - frame = ale._get_raw_screen().transpose((2, 0, 1)) - np.testing.assert_equal(frame, buffer_[i]) + frame = ale._get_raw_screen().squeeze() + np.testing.assert_equal(ale._processor._buffer[-1], frame) + np.testing.assert_equal(ale._processor._buffer[-2], previous_frame) + previous_frame = frame class PreprocessorTest(unittest.TestCase): # pylint: disable=no-self-use def test_buffer_frame(self): """The latest frame is correctly passed to preprocessor buffer""" - _test_buffer(grayscale=True) + _test_processor_buffer(grayscale=True) + + def test_buffer_frame_color(self): + """The latest frame is correctly passed to preprocessor buffer""" + _test_processor_buffer(grayscale=False) + + +class StackTest(unittest.TestCase): + # pylint: disable=no-self-use + def _test_stack_buffer(self, grayscale): + stack = 4 + ale = ALE(rom='breakout', stack=stack, grayscale=grayscale) + previous_stack = ale.reset().state + + for _ in range(stack): + stack = ale.step(1).state + np.testing.assert_equal(previous_stack[1:], stack[:-1]) + self.assertEqual(previous_stack.shape, stack.shape) + self.assertFalse((previous_stack == stack).all()) + previous_stack = stack + + def test_buffer_frame(self): + """The latest frame is correctly passed to preprocessor buffer""" + self._test_stack_buffer(grayscale=True) def test_buffer_frame_color(self): """The latest frame is correctly passed to preprocessor buffer""" - _test_buffer(grayscale=False) + self._test_stack_buffer(grayscale=False) class ALEEnvironmentTest(unittest.TestCase):