From e817152b078171fde35faa7176bb4b95573a07b8 Mon Sep 17 00:00:00 2001
From: Deniz Seven
Date: Sat, 3 Aug 2024 12:08:22 +0200
Subject: [PATCH 1/7] fix: move sample inside gradient step loop in TD3RNN, DDPGRNN, SAC, SACRNN, DQN, DDQN

---
 docs/source/api/agents/ddqn.rst | 4 ++--
 docs/source/api/agents/dqn.rst | 4 ++--
 docs/source/api/agents/sac.rst | 4 ++--
 skrl/agents/jax/dqn/ddqn.py | 9 +++++----
 skrl/agents/jax/dqn/dqn.py | 7 ++++---
 skrl/agents/jax/sac/sac.py | 7 ++++---
 skrl/agents/jax/td3/td3.py | 1 +
 skrl/agents/torch/ddpg/ddpg_rnn.py | 17 +++++++++--------
 skrl/agents/torch/dqn/ddqn.py | 9 +++++----
 skrl/agents/torch/dqn/dqn.py | 7 ++++---
 skrl/agents/torch/sac/sac.py | 9 +++++----
 skrl/agents/torch/sac/sac_rnn.py | 17 +++++++++--------
 skrl/agents/torch/td3/td3_rnn.py | 17 +++++++++--------
 13 files changed, 61 insertions(+), 51 deletions(-)

diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst
index 5208192e..095cd967 100644
--- a/docs/source/api/agents/ddqn.rst
+++ b/docs/source/api/agents/ddqn.rst
@@ -40,10 +40,10 @@ Learning algorithm
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+| :green:`# sample a batch from memory`
+| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# compute target values`
 | :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
 | :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN`
diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst
index 74e2ca04..d5d72e7b 100644
--- a/docs/source/api/agents/dqn.rst
+++ b/docs/source/api/agents/dqn.rst
@@ -40,10 +40,10 @@ Learning algorithm
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+| :green:`# sample a batch from memory`
+| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# compute target values`
 | :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
 | :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN`
diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst
index b465acbb..cf1e4265 100644
--- a/docs/source/api/agents/sac.rst
+++ b/docs/source/api/agents/sac.rst
@@ -34,10 +34,10 @@ Learning algorithm
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+| :green:`# sample a batch from memory`
+| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# compute target values`
 | :math:`a',\; logp' \leftarrow \pi_\theta(s')`
 | :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index f11ebe9b..4a727b3b 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -337,13 +337,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
-
+
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py
index e75247d7..8c245eb4 100644
--- a/skrl/agents/jax/dqn/dqn.py
+++ b/skrl/agents/jax/dqn/dqn.py
@@ -334,13 +334,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py
index 8f77da3b..b487c291 100644
--- a/skrl/agents/jax/sac/sac.py
+++ b/skrl/agents/jax/sac/sac.py
@@ -397,13 +397,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index 6a1c909c..9a96f58d 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -436,6 +436,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
+
             # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
                 self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py
index 436184c1..2ace360d 100644
--- a/skrl/agents/torch/ddpg/ddpg_rnn.py
+++ b/skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -361,18 +361,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 8d181ac8..c2ebc04c 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -284,13 +284,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
-
+
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index c7f4f709..7ee9e842 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -284,13 +284,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py
index dc8678a6..63de9935 100644
--- a/skrl/agents/torch/sac/sac.py
+++ b/skrl/agents/torch/sac/sac.py
@@ -308,13 +308,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
-
+
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py
index 6553d7aa..21a86fb7 100644
--- a/skrl/agents/torch/sac/sac_rnn.py
+++ b/skrl/agents/torch/sac/sac_rnn.py
@@ -345,18 +345,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py
index abd6a922..1baab8d5 100644
--- a/skrl/agents/torch/td3/td3_rnn.py
+++ b/skrl/agents/torch/td3/td3_rnn.py
@@ -383,18 +383,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

From fd729084652ee5c3984548c26e82dcd7364b9d7b Mon Sep 17 00:00:00 2001
From: Deniz Seven
Date: Sat, 3 Aug 2024 12:09:36 +0200
Subject: [PATCH 2/7] chore: update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c61e475b..d090fff2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 ### Changed
 - Move the KL reduction from the PyTorch `KLAdaptiveLR` class to each agent using it in distributed runs
 - Move the PyTorch distributed initialization from the agent base class to the ML framework configuration
+- Moved the batch sampling inside gradient step loop for DDPGRNN, TD3RNN, SAC, SACRNN, DQN and DDQN.
 
 ### Fixed
 - Catch TensorBoard summary iterator exceptions in `TensorboardFileIterator` postprocessing utils

From 99cf16bc801786ece2727f0892269fccbb71a541 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Sat, 2 Nov 2024 15:28:24 -0400
Subject: [PATCH 3/7] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d090fff2..a0112deb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 ### Changed
 - Move the KL reduction from the PyTorch `KLAdaptiveLR` class to each agent using it in distributed runs
 - Move the PyTorch distributed initialization from the agent base class to the ML framework configuration
-- Moved the batch sampling inside gradient step loop for DDPGRNN, TD3RNN, SAC, SACRNN, DQN and DDQN.
+- Moved the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)
 
 ### Fixed
 - Catch TensorBoard summary iterator exceptions in `TensorboardFileIterator` postprocessing utils

From a3689e3022c0d3817b703f5404de4aa42f786559 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Sun, 3 Nov 2024 09:51:29 -0500
Subject: [PATCH 4/7] Apply format to ddqn.py in jax

---
 skrl/agents/jax/dqn/ddqn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index 679eab18..76868e68 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -339,11 +339,11 @@ def _update(self, timestep: int, timesteps: int) -> None:
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
-
+
             # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
                 self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
-
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

From f93ee335a2c7dca3cd07af827a187d33603dafa3 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Sun, 3 Nov 2024 09:52:21 -0500
Subject: [PATCH 5/7] Apply format to td3.py in jax

---
 skrl/agents/jax/td3/td3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index e0ffe646..23f4885a 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -435,7 +435,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
-
+
             # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
                 self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

From 5b9787c727462910220da77f8f35a54594c00283 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Sun, 3 Nov 2024 09:54:44 -0500
Subject: [PATCH 6/7] Apply format to dqn.py in torch

---
 skrl/agents/torch/dqn/dqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index 27925e6f..03ffa320 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -290,7 +290,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
             # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
                 self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
-
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

From 4cdd7f4378d8904783305141276865e6ef81cdf1 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Sun, 3 Nov 2024 09:55:31 -0500
Subject: [PATCH 7/7] Apply format to ddqn.py in torch

---
 skrl/agents/torch/dqn/ddqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 36469cb7..d7e93886 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -283,7 +283,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-
+
         # gradient steps
         for gradient_step in range(self._gradient_steps):
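
Note (illustration only, not part of the patch series): the behavioural change these commits make to the off-policy agents is that a new mini-batch is drawn from the replay memory at every gradient step, instead of drawing a single batch once and reusing it for all `gradient_steps` updates. The minimal Python sketch below contrasts the two patterns; `ReplayMemory`, `dummy_loss`, and the two update functions are hypothetical stand-ins for this illustration, not skrl's actual classes or API.

```python
import random
from typing import List, Tuple

# a transition is (state, action, reward, next_state, done) -- simplified to scalars
Transition = Tuple[float, int, float, float, bool]


class ReplayMemory:
    """Minimal replay buffer used only for this sketch."""

    def __init__(self) -> None:
        self._storage: List[Transition] = []

    def add(self, transition: Transition) -> None:
        self._storage.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        # uniform sampling, standing in for memory.sample(...) in the agents
        return random.sample(self._storage, batch_size)


def dummy_loss(batch: List[Transition]) -> float:
    # stand-in for the TD-target / critic-loss computation done by the real agents
    return sum(reward for _, _, reward, _, _ in batch) / len(batch)


def update_before(memory: ReplayMemory, gradient_steps: int, batch_size: int) -> List[float]:
    # old pattern: one batch sampled before the loop and reused for every gradient step
    batch = memory.sample(batch_size)
    return [dummy_loss(batch) for _ in range(gradient_steps)]


def update_after(memory: ReplayMemory, gradient_steps: int, batch_size: int) -> List[float]:
    # new pattern: each gradient step optimizes on a freshly sampled batch
    return [dummy_loss(memory.sample(batch_size)) for _ in range(gradient_steps)]


if __name__ == "__main__":
    memory = ReplayMemory()
    for _ in range(1000):
        memory.add((random.random(), random.randrange(4), random.random(), random.random(), False))
    # the reused-batch losses are identical across steps; the fresh-batch losses vary
    print("losses (reused batch): ", update_before(memory, gradient_steps=4, batch_size=64))
    print("losses (fresh batches):", update_after(memory, gradient_steps=4, batch_size=64))
```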