separate out storage and computation devices in rollout buffer
eugenevinitsky committed May 8, 2024
1 parent 8d30b22 commit 9908151
Showing 1 changed file with 18 additions and 15 deletions.
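The change lets the rollout tensors live in host memory (storage_device="cpu") while the policy and the GAE arithmetic keep running on the compute device (device). A minimal usage sketch; the class name RolloutBuffer, the import path, the gymnasium import, and the leading buffer_size argument are assumptions (taken from the file name, the diff header, and the SB3 signature this buffer mirrors), not things shown in the commit:

import gymnasium as gym  # assumption: the file itself only needs gym.spaces
import torch

from algorithms.ppo.sb3.rollout_buffer import RolloutBuffer  # class name assumed from the file name

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,))
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))

buffer = RolloutBuffer(
    buffer_size=2048,                 # assumed leading argument, as in the SB3 base class
    observation_space=obs_space,
    action_space=act_space,
    device="cuda" if torch.cuda.is_available() else "cpu",  # computation device
    storage_device="cpu",             # new: where the rollout tensors are kept
    gae_lambda=0.95,
    gamma=0.99,
    n_envs=16,
)

Keeping rollouts on the CPU trades a small per-step host copy for a much lower GPU-memory footprint, which matters once buffer_size x n_envs x obs_shape gets large.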
algorithms/ppo/sb3/rollout_buffer.py (18 additions & 15 deletions)
@@ -29,6 +29,7 @@ def __init__(
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
device: Union[torch.device, str] = "auto",
+        storage_device: Union[torch.device, str] = "cpu",  # TODO(ev) add storage device to config
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
@@ -39,48 +40,49 @@
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
+        self.storage_device = storage_device
self.reset()

def reset(self) -> None:
"""Reset the buffer."""
self.observations = torch.zeros(
(self.buffer_size, self.n_envs, *self.obs_shape),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.actions = torch.zeros(
(self.buffer_size, self.n_envs, self.action_dim),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.rewards = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.returns = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.episode_starts = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.values = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.log_probs = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.advantages = torch.zeros(
(self.buffer_size, self.n_envs),
-            device=self.device,
+            device=self.storage_device,
dtype=torch.float32,
)
self.generator_ready = False
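All eight tensors allocated in reset() now live on storage_device, so the learning phase (not shown in this commit) has to move each sampled minibatch back to the compute device before the PPO update. Continuing the sketch above, a hypothetical sampling step; the minibatch size and the single-env indexing are illustrative, not part of the commit:

# Slice a minibatch out of CPU storage and move only that slice to the
# compute device; the full buffer never leaves host memory.
batch_inds = torch.randint(0, buffer.buffer_size, (256,))
obs_batch = buffer.observations[batch_inds, 0].to(buffer.device)
act_batch = buffer.actions[batch_inds, 0].to(buffer.device)
adv_batch = buffer.advantages[batch_inds, 0].to(buffer.device)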
@@ -110,12 +112,12 @@ def add(
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))

-        self.observations[self.pos] = obs
-        self.actions[self.pos] = action
-        self.rewards[self.pos] = reward
-        self.episode_starts[self.pos] = episode_start
-        self.values[self.pos] = value.flatten()
-        self.log_probs[self.pos] = log_prob.clone()
+        self.observations[self.pos] = obs.to(self.storage_device)
+        self.actions[self.pos] = action.to(self.storage_device)
+        self.rewards[self.pos] = reward.to(self.storage_device)
+        self.episode_starts[self.pos] = episode_start.to(self.storage_device)
+        self.values[self.pos] = value.flatten().to(self.storage_device)
+        self.log_probs[self.pos] = log_prob.clone().to(self.storage_device)
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
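With the .to(self.storage_device) calls above, add() accepts tensors produced on the compute device and copies them into host storage at every step. Continuing the same sketch, one rollout step feeding the buffer; the argument order follows the SB3 add() signature this class mirrors, which is an assumption since the full signature sits outside the hunk:

device = "cuda" if torch.cuda.is_available() else "cpu"  # compute device from the sketch above
obs = torch.randn(16, 8, device=device)        # (n_envs, *obs_shape)
action = torch.randn(16, 2, device=device)     # (n_envs, action_dim)
reward = torch.zeros(16, device=device)
episode_start = torch.zeros(16, device=device)
value = torch.zeros(16, 1, device=device)      # critic output; flattened inside add()
log_prob = torch.zeros(16, device=device)

buffer.add(obs, action, reward, episode_start, value, log_prob)
# The stored copies now live on storage_device ("cpu"), wherever the inputs were produced.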
@@ -125,7 +127,8 @@ def compute_returns_and_advantage(
) -> None:
"""GAE (General Advantage Estimation) to compute advantages and returns."""
# Convert to numpy
-        last_values = last_values.clone().flatten()
+        last_values = last_values.clone().flatten().to(self.storage_device)
+        dones = dones.clone().flatten().to(self.storage_device)

last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
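The loop body is truncated here; it is presumably the standard SB3-style GAE recursion, which after this commit runs entirely on storage_device because last_values and dones are moved there above. A standalone reference sketch of that recursion, reconstructed from the SB3 implementation this buffer mirrors rather than taken from the commit (torch is imported in the first sketch above):

def gae_reference(rewards, values, episode_starts, last_values, dones, gamma, gae_lambda):
    """Sketch of GAE; rewards/values/episode_starts are (buffer_size, n_envs) tensors on one device."""
    buffer_size = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae_lam = torch.zeros_like(last_values)
    for step in reversed(range(buffer_size)):
        if step == buffer_size - 1:
            next_non_terminal = 1.0 - dones.float()
            next_values = last_values
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_values = values[step + 1]
        delta = rewards[step] + gamma * next_values * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    returns = advantages + values
    return advantages, returns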
