From 19c85a1fd8fd56437c2f9c1de081889e49554604 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 11 Jul 2024 14:02:30 +0200 Subject: [PATCH] Add CNN support for DQN (#49) * Add CNN support for DQN * Update version and deps * Fix CNN, channel last, padding and reshape --- sbx/dqn/dqn.py | 4 +++- sbx/dqn/policies.py | 54 ++++++++++++++++++++++++++++++++++++++++++++- sbx/version.txt | 2 +- setup.py | 4 ++-- tests/test_cnn.py | 44 ++++++++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 tests/test_cnn.py diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index cde94b8..852b823 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -11,12 +11,13 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.dqn.policies import DQNPolicy +from sbx.dqn.policies import CNNPolicy, DQNPolicy class DQN(OffPolicyAlgorithmJax): policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment] "MlpPolicy": DQNPolicy, + "CnnPolicy": CNNPolicy, } # Linear schedule will be defined in `_setup_model()` exploration_schedule: Schedule @@ -36,6 +37,7 @@ def __init__( exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, + optimize_memory_usage: bool = False, # Note: unused but to match SB3 API # max_grad_norm: float = 10, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index b9226ee..d8b19ba 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -28,6 +28,32 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x +# Add CNN policy from DQN paper +class NatureCNN(nn.Module): + n_actions: int + n_units: int = 512 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # Convert from channel-first (PyTorch) to channel-last (Jax) + x = jnp.transpose(x, (0, 2, 3, 1)) + # Convert to float and normalize the image + x = x.astype(jnp.float32) / 255.0 + x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) + x = self.activation_fn(x) + # Flatten + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(self.n_units)(x) + x = self.activation_fn(x) + x = nn.Dense(self.n_actions)(x) + return x + + class DQNPolicy(BaseJaxPolicy): action_space: spaces.Discrete # type: ignore[assignment] @@ -65,7 +91,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: obs = jnp.array([self.observation_space.sample()]) - self.qf = QNetwork( + self.qf: nn.Module = QNetwork( n_actions=int(self.action_space.n), n_units=self.n_units, activation_fn=self.activation_fn, @@ -97,3 +123,29 @@ def select_action(qf_state, observations): def _predict(self, observation: np.ndarray, deterministic: bool = True) -> np.ndarray: # type: ignore[override] return DQNPolicy.select_action(self.qf_state, observation) + + +class CNNPolicy(DQNPolicy): + def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: + key, qf_key = jax.random.split(key, 2) + + obs = jnp.array([self.observation_space.sample()]) + + self.qf = NatureCNN( + n_actions=int(self.action_space.n), + n_units=self.n_units, + activation_fn=self.activation_fn, + ) + + self.qf_state = RLTrainState.create( + apply_fn=self.qf.apply, + params=self.qf.init({"params": qf_key}, obs), + target_params=self.qf.init({"params": qf_key}, obs), + tx=self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), + ) + self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign] + + return key diff --git a/sbx/version.txt b/sbx/version.txt index 04a373e..c5523bd 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.16.0 +0.17.0 diff --git a/setup.py b/setup.py index 3540c3b..96c390f 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ ## Example ```python -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ +from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ model = TQC("MlpPolicy", "Pendulum-v1", verbose=1) model.learn(total_timesteps=10_000, progress_bar=True) @@ -40,7 +40,7 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.3.0", + "stable_baselines3>=2.4.0a4,<3.0", "jax", "jaxlib", "flax", diff --git a/tests/test_cnn.py b/tests/test_cnn.py new file mode 100644 index 0000000..78b209b --- /dev/null +++ b/tests/test_cnn.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest +from stable_baselines3.common.envs import FakeImageEnv + +from sbx import DQN + + +@pytest.mark.parametrize("model_class", [DQN]) +def test_cnn(tmp_path, model_class): + SAVE_NAME = "cnn_model.zip" + # Fake grayscale with frameskip + # Atari after preprocessing: 84x84x1, here we are using lower resolution + # to check that the network handle it automatically + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1) + model = model_class( + "CnnPolicy", + env, + buffer_size=250, + policy_kwargs=dict(net_arch=[64]), + learning_starts=100, + verbose=1, + ) + model.learn(total_timesteps=250) + + obs, _ = env.reset() + + # Test stochastic predict with channel last input + if model_class == DQN: + model.exploration_rate = 0.9 + + for _ in range(10): + model.predict(obs, deterministic=False) + + action, _ = model.predict(obs, deterministic=True) + + model.save(tmp_path / SAVE_NAME) + del model + + model = model_class.load(tmp_path / SAVE_NAME) + + # Check that the prediction is the same + assert np.allclose(action, model.predict(obs, deterministic=True)[0]) + + (tmp_path / SAVE_NAME).unlink()