Fix Sampling inside gradient loop issue #183

CHANGELOG.md (3 additions, 0 deletions)
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Make flattened tensor storage in memory the default option (revert change introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).

### Fixed
- Moved the batch sampling inside the gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)

### Removed
- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments;
  the package is simply no longer installed with the library and must be installed manually if needed.
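For orientation, here is a minimal sketch of the before/after pattern this PR applies to the agents' `_update` methods. It is illustrative pseudocode, not skrl's actual implementation: `_optimize` is a hypothetical placeholder for the loss and optimizer code, while `memory.sample`, `tensors_names`, `_batch_size` and `_gradient_steps` are the names that appear in the diffs below. Sampling inside the loop gives every gradient step its own mini-batch instead of reusing the single batch drawn before the loop.

```python
# Before the fix: one batch is sampled once and reused for every gradient step.
def _update_before(self, timestep, timesteps):
    batch = self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
    for gradient_step in range(self._gradient_steps):
        self._optimize(*batch)  # hypothetical helper standing in for the loss/optimizer code

# After the fix: a fresh batch is sampled at every gradient step.
def _update_after(self, timestep, timesteps):
    for gradient_step in range(self._gradient_steps):
        batch = self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
        self._optimize(*batch)
```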
docs/source/api/agents/ddqn.rst (2 additions, 2 deletions)
@@ -40,10 +40,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN`
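As a complement to the pseudocode above, a PyTorch-style sketch of the DDQN target computation (illustrative only, not skrl's code; the networks and batch tensors are assumed to exist with the usual shapes):

```python
import torch

def ddqn_targets(q_network, target_q_network, rewards, next_states, dones, gamma=0.99):
    """TD targets for DDQN: y = r + gamma * (1 - d) * Q'_target(s', argmax_a Q(s', a))."""
    with torch.no_grad():
        best_actions = q_network(next_states).argmax(dim=1, keepdim=True)  # action selected by the online network
        next_q = target_q_network(next_states).gather(1, best_actions)     # evaluated by the target network (the DDQN difference)
        return rewards + gamma * (1 - dones.float()) * next_q
```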
docs/source/api/agents/dqn.rst (2 additions, 2 deletions)
@@ -40,10 +40,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN`
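And the corresponding sketch for vanilla DQN, where the target network both selects and evaluates the next action (again illustrative, not skrl's code):

```python
import torch

def dqn_targets(target_q_network, rewards, next_states, dones, gamma=0.99):
    """TD targets for DQN: y = r + gamma * (1 - d) * max_a Q'_target(s', a)."""
    with torch.no_grad():
        next_q = target_q_network(next_states).max(dim=1, keepdim=True).values  # max over actions (the only difference with DDQN)
        return rewards + gamma * (1 - dones.float()) * next_q
```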
docs/source/api/agents/sac.rst (2 additions, 2 deletions)
@@ -34,10 +34,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a',\; logp' \leftarrow \pi_\theta(s')`
| :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
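The SAC hunk is truncated above; for reference, a sketch of the entropy-regularized target it leads into, written from the standard SAC formulation rather than skrl's exact code (a policy returning `(action, log_prob)` and two target critics are assumed):

```python
import torch

def sac_targets(policy, target_q1, target_q2, rewards, next_states, dones, alpha, gamma=0.99):
    """TD targets for SAC: y = r + gamma * (1 - d) * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    with torch.no_grad():
        next_actions, next_log_prob = policy(next_states)    # a', logp' ~ pi_theta(s')
        q1 = target_q1(next_states, next_actions)            # Q1_target(s', a')
        q2 = target_q2(next_states, next_actions)            # Q2_target(s', a')
        next_q = torch.min(q1, q2) - alpha * next_log_prob   # clipped double-Q with entropy bonus
        return rewards + gamma * (1 - dones.float()) * next_q
```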
skrl/agents/jax/dqn/ddqn.py (4 additions, 3 deletions)
@@ -336,13 +336,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/jax/dqn/dqn.py (4 additions, 3 deletions)
@@ -333,13 +333,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/jax/sac/sac.py (4 additions, 3 deletions)
@@ -396,13 +396,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/jax/td3/td3.py (1 addition, 0 deletions)
@@ -435,6 +435,7 @@ def _update(self, timestep: int, timesteps: int) -> None:

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
skrl/agents/torch/ddpg/ddpg_rnn.py (9 additions, 8 deletions)
@@ -360,18 +360,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

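In the RNN variants (DDPG, SAC and TD3), the RNN-state lookup moves inside the loop together with the batch sampling. The sketch below restates the ordering using the calls from the diff above; the assumption, suggested by the names, is that `get_sampling_indexes()` returns the indexes of the most recent `sample()` call, so the lookup must follow the per-step sample to stay aligned with that step's batch:

```python
for gradient_step in range(self._gradient_steps):
    # draw this step's batch of sequences
    sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
        self.memory.sample(names=self._tensors_names, batch_size=self._batch_size,
                           sequence_length=self._rnn_sequence_length)[0]

    rnn_policy = {}
    if self._rnn:
        # must come after this step's sample(): the indexes refer to the batch just drawn
        sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names,
                                                  indexes=self.memory.get_sampling_indexes())[0]
        rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
```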
skrl/agents/torch/dqn/ddqn.py (4 additions, 3 deletions)
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/torch/dqn/dqn.py (4 additions, 3 deletions)
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/torch/sac/sac.py (2 additions, 1 deletion)
@@ -307,12 +307,13 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
skrl/agents/torch/sac/sac_rnn.py (9 additions, 8 deletions)
@@ -344,18 +344,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

skrl/agents/torch/td3/td3_rnn.py (9 additions, 8 deletions)
@@ -382,18 +382,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
