Robust PLR #29
base: main
Changes from 7 commits
@@ -47,7 +47,7 @@ def __init__(
        staleness_transform: str = "power",
        staleness_temperature: float = 1.0,
        robust_plr: bool = False,
        eval_envs: List[gym.Env] = None,
        eval_envs = None,
        action_value_fn=None,
    ):
        self.action_space = action_space
@@ -109,6 +109,28 @@ def update_with_rollouts(self, rollouts):

        self._update_with_rollouts(rollouts, score_function)

    def update_with_episode_data(self, episode_data):
        if self.strategy == "random":
            return

        # Update with a EpisodeRolloutStorage object
        if self.strategy == "policy_entropy":
            score_function = self._average_entropy
        elif self.strategy == "least_confidence":
            score_function = self._average_least_confidence
        elif self.strategy == "min_margin":
            score_function = self._average_min_margin
        elif self.strategy == "gae":
            score_function = self._average_gae
        elif self.strategy == "value_l1":
            score_function = self._average_value_l1
        elif self.strategy == "one_step_td_error":
            score_function = self._one_step_td_error
        else:
            raise ValueError(f"Unsupported strategy, {self.strategy}")

        self._update_with_episode_data(episode_data, score_function)

    def update_task_score(self, actor_index, task_idx, score, num_steps):
        score = self._partial_update_task_score(actor_index, task_idx, score, num_steps, done=True)
@@ -275,32 +297,66 @@ def _sample_unseen_level(self):

        return task

    def compute_returns(self, gamma, gae_lambda, rewards, value_preds, masks):
    def _evaluate_unseen_level(self):
        sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum()
        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        task = self.tasks[task_idx]

        episode_data = self.evaluate_task(task, self.eval_envs, self.action_value_fn)
        self.update_with_episode_data(episode_data)

        self._update_staleness(task_idx)

        return task

Review comment (suggested change): No need to return the task here, see my other comments.
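For reference, a minimal sketch of the method with that suggestion applied, i.e. identical to the added code above but with the trailing return dropped (the attributes used are the ones shown in this diff):

    def _evaluate_unseen_level(self):
        sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum()
        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        task = self.tasks[task_idx]

        # Evaluate the sampled level and fold its scores back into the sampler.
        episode_data = self.evaluate_task(task, self.eval_envs, self.action_value_fn)
        self.update_with_episode_data(episode_data)

        self._update_staleness(task_idx)
        # No return value: callers re-enter sample() to draw the next level.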

    def compute_returns(self, rewards, values, masks, gamma, gae_lambda):
        assert self.requires_value_buffers, "Selected strategy does not use compute_rewards."

        if isinstance(rewards, float):
            rewards = [np.array(rewards)]

        num_steps = len(rewards)
        gae = 0
        returns = torch.zeros_like(rewards)
        for step in reversed(range(rewards.size(0))):
            delta = (
                rewards[step]
                + gamma * value_preds[step + 1] * masks[step + 1]
                - value_preds[step]
            )
            gae = delta + gamma * gae_lambda * masks[step + 1] * gae
            returns[step] = gae + value_preds[step]
        returns = np.zeros_like(rewards)
        for step in reversed(range(num_steps)):
            # Check if we are at the last step
            if step == num_steps - 1:
                delta = (rewards[step] - values[step])
                gae = delta
            else:
                next_value = values[step + 1] if step + 1 < num_steps else 0
                next_mask = masks[step + 1] if step + 1 < num_steps else 0
                delta = (rewards[step]
                         + gamma * next_value * next_mask
                         - values[step]
                         )
                gae = delta + gamma * gae_lambda * next_mask * gae

            gae_scal = gae[0] if isinstance(gae, np.ndarray) else gae
            value_scal = values[step][0] if isinstance(values[step], np.ndarray) else values[step]

            returns[step] = gae_scal + value_scal

Review comment: Why are you changing this function? You should not need to modify the GAE code, it's already correct.
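As context for this comment, the pre-existing block being replaced above implements standard GAE return targets. A self-contained NumPy sketch of that recursion, assuming values and masks carry one more entry than rewards (bootstrap value and terminal mask appended by the caller), is:

    import numpy as np

    def gae_returns(rewards, values, masks, gamma, gae_lambda):
        # rewards: length T; values, masks: length T + 1 (include bootstrap entries).
        num_steps = len(rewards)
        returns = np.zeros(num_steps)
        gae = 0.0
        for step in reversed(range(num_steps)):
            # One-step TD error, zeroed across episode boundaries by the mask.
            delta = rewards[step] + gamma * values[step + 1] * masks[step + 1] - values[step]
            gae = delta + gamma * gae_lambda * masks[step + 1] * gae
            # Return target = advantage estimate + value baseline.
            returns[step] = gae + values[step]
        return returns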

        return returns

    def evaluate_task(self, task, env, action_value_fn, gamma, gae_lambda):
    def evaluate_task(self, task, env, action_value_fn):
        if env is None:
            raise ValueError("Environment object is None. Please ensure it is properly initialized.")

        obs = env.reset(next_task=task)
        obs = env.reset(new_task=task)
        done = False
        rewards = []
        masks = []
        values = []

        while not done:
            action, value = action_value_fn(obs)

            if isinstance(action, np.ndarray):
                action = int(action[0])

Review comment: You shouldn't do this, instead make sure that action_value_fn returns a single action rather than an np.ndarray.
||||||
else: | ||||||
action = int(action) | ||||||
|
||||||
obs, rew, term, trunc, info = env.step(action) | ||||||
|
||||||
rewards.append(rew) | ||||||
|
@@ -311,34 +367,16 @@ def evaluate_task(self, task, env, action_value_fn, gamma, gae_lambda): | |||||
if term or trunc: | ||||||
done = True | ||||||
|
||||||
# Compute returns after the episode is complete | ||||||
returns = self.compute_returns(gamma, gae_lambda, rewards, values, masks) | ||||||
returns = self.compute_returns(rewards, values, masks, self.gamma, self.gae_lambda) | ||||||
|
||||||
return { | ||||||
"tasks": task, | ||||||
"masks": masks, | ||||||
"rewards": rewards, | ||||||
"value_preds": values, | ||||||
"values": values, | ||||||
"returns": returns | ||||||
} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once you clean up everything else, we should see how fast this is and come up with some ideas to optimize it |

    def _sample_unseen_level(self):
        sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum()
        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        task = self.tasks[task_idx]

        self._update_staleness(task_idx)

        return task

    def _evaluate_unseen_level(self):
        task_idx = \
            np.random.choice(range(self.num_tasks), 1, p=self.unseen_task_weights / self.unseen_task_weights.sum())[0]
        task = self.tasks[task_idx]
        episode_data = self.evaluate_task(task, self.eval_envs, self.action_value_fn, self.gamma, self.gae_lambda)
        self.update_with_episode_data(episode_data, self._average_gae)  # Update task scores
        return task

    def sample(self, strategy=None):
        if not strategy:
            strategy = self.strategy

@@ -363,44 +401,32 @@ def sample(self, strategy=None):
            if np.random.rand() > self.nu or not proportion_seen < 1.0:
                return self._sample_replay_level()

            # Otherwise, evaluate a new level
            # Otherwise, sample a new level
            if self.robust_plr:
                self.update_with_episode_data(self._evaluate_unseen_level())
Review comment (with a suggested change): You call update_with_episode_data on the value returned by _evaluate_unseen_level here, but _evaluate_unseen_level returns a task and already calls update_with_episode_data internally.
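A sketch of what the suggested change presumably amounts to, given that _evaluate_unseen_level already performs the score update itself (this is an interpretation of the truncated suggestion, not text from the PR):

            # Otherwise, evaluate a new level and resample
            if self.robust_plr:
                self._evaluate_unseen_level()  # updates task scores internally
                return self.sample(strategy=strategy)
            else:
                return self._sample_unseen_level()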

                return self.sample(strategy=strategy)
            else:
                # Otherwise, sample a new level
                return self._sample_unseen_level()

Review comment: This is a good start, but it's going to be very inefficient to do these evaluations in the main process. We'll probably want to batch and multiprocess them in the future, but for now this is good as a proof of concept.

        elif self.replay_schedule == "proportionate":
            if proportion_seen >= self.rho and np.random.rand() < proportion_seen:
                return self._sample_replay_level()
            else:
                if self.robust_plr:
                    while True:
                        task = self._evaluate_unseen_level()
                        episode_data = self.evaluate_task(task, self.eval_envs, self.action_value_fn, self.gamma,
                                                          self.gae_lambda)
                        self.update_with_episode_data(episode_data, self._average_gae)  # Update task scores

                        # Check if we need to sample another unseen level
                        num_unseen = (self.unseen_task_weights > 0).sum()
                        proportion_seen = (self.num_tasks - num_unseen) / self.num_tasks
                        if proportion_seen < self.rho or np.random.rand() >= proportion_seen:
                            break

                    return task
                    self.update_with_episode_data(self._evaluate_unseen_level())
Review comment (with a suggested change): You call update_with_episode_data on the value returned by _evaluate_unseen_level here as well; as above, _evaluate_unseen_level already performs that update internally.

                    return self.sample(strategy=strategy)
                else:
                    return self._sample_unseen_level()

        else:
            raise NotImplementedError(
                f"Unsupported replay schedule: {self.replay_schedule}. Must be 'fixed' or 'proportionate'.")

    def update_with_episode_data(self, episode_data, score_function):
        tasks = episode_data['tasks']
        done = ~(episode_data['masks'] > 0)
        total_steps, num_actors = episode_data['tasks'].shape[:2]
    def _update_with_episode_data(self, episode_data, score_function):
        tasks = np.array(episode_data["tasks"])
        if not self.requires_value_buffers:
            policy_logits = episode_data.action_log_dist
        done = np.array([not mask > 0 for mask in episode_data["masks"]])
        total_steps, num_actors = tasks.shape[:2]

        for actor_index in range(num_actors):
            done_steps = done[:, actor_index].nonzero()[:total_steps, 0]

@@ -413,25 +439,47 @@ def update_with_episode_data(self, episode_data, score_function):
                if (t == 0):  # if t is 0, then this done step caused a full update of previous last cycle
                    continue

                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and t - start_t <= 1:
                    continue

                task_idx_t = tasks[start_t, actor_index].item()

                # Store kwargs for score function
                score_function_kwargs = {}
                score_function_kwargs["returns"] = episode_data['returns'][start_t:t, actor_index]
                score_function_kwargs["value_preds"] = episode_data['value_preds'][start_t:t, actor_index]
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = episode_data.returns[start_t:t, actor_index]
                    score_function_kwargs["rewards"] = episode_data.rewards[start_t:t, actor_index]
                    score_function_kwargs["values"] = episode_data.values[start_t:t, actor_index]
                else:
                    episode_logits = policy_logits[start_t:t, actor_index]
                    score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1)
                score = score_function(**score_function_kwargs)
                num_steps = len(episode_data['tasks'][start_t:t, actor_index])
                num_steps = len(episode_data.tasks[start_t:t, actor_index])
                # TODO: Check that task_idx_t is correct
                self.update_task_score(actor_index, task_idx_t, score, num_steps)

                start_t = t.item()
            if start_t < total_steps:
                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and start_t == total_steps - 1:
                    continue
                # TODO: Check this too
                task_idx_t = tasks[start_t, actor_index].item()

                # Store kwargs for score function
                score_function_kwargs = {}
                score_function_kwargs["returns"] = episode_data['returns'][start_t:, actor_index]
                score_function_kwargs["value_preds"] = episode_data['value_preds'][start_t:, actor_index]
                if self.requires_value_buffers:
                    score_function_kwargs["returns"] = episode_data.returns[start_t:, actor_index]
                    score_function_kwargs["rewards"] = episode_data.rewards[start_t:, actor_index]
                    score_function_kwargs["values"] = episode_data.values[start_t:, actor_index]
                else:
                    episode_logits = policy_logits[start_t:, actor_index]
                    score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1)

                score = score_function(**score_function_kwargs)
                self._last_score = score
                num_steps = len(episode_data['tasks'][start_t:, actor_index])
                num_steps = len(episode_data.tasks[start_t:, actor_index])
                self._partial_update_task_score(actor_index, task_idx_t, score, num_steps)
Review comment: Could you reduce the code duplication between update_with_rollouts and this? Maybe create an inner function that does most of the computation, and a helper function that converts rollouts to an episode_data dictionary. So update_with_rollouts will first move data into the episode_data dictionary and then call update_with_episode_data.
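A rough sketch of the structure being described, assuming the rollout storage exposes the attributes referenced elsewhere in this diff (the helper name _rollouts_to_episode_data is illustrative, not from the PR):

    def update_with_rollouts(self, rollouts):
        # Convert the rollout storage to the episode_data format, then reuse the
        # episode-data path so the per-strategy scoring logic lives in one place.
        episode_data = self._rollouts_to_episode_data(rollouts)
        self.update_with_episode_data(episode_data)

    def _rollouts_to_episode_data(self, rollouts):
        # Attribute names are assumptions based on how episode_data is consumed above.
        return {
            "tasks": rollouts.tasks,
            "masks": rollouts.masks,
            "rewards": rollouts.rewards,
            "values": rollouts.value_preds,
            "returns": rollouts.returns,
        }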

    def sample_weights(self):
@@ -1,5 +1,6 @@
""" Test curriculum synchronization across multiple processes. """ | ||
import pytest | ||
import gym | ||
from nle.env.tasks import NetHackScore, NetHackScout, NetHackStaircase | ||
|
||
from syllabus.core import make_multiprocessing_curriculum, make_ray_curriculum | ||
|
@@ -9,14 +10,13 @@ | |
LearningProgressCurriculum, NoopCurriculum, | ||
PrioritizedLevelReplay, SequentialCurriculum, | ||
SimpleBoxCurriculum) | ||
from syllabus.tests import (create_cartpole_env, create_nethack_env, | ||
from syllabus.tests import (create_cartpole_env, create_nethack_env, get_action_value, | ||
get_test_values, run_native_multiprocess, | ||
run_ray_multiprocess, run_single_process) | ||
|
||
N_ENVS = 2 | ||
N_EPISODES = 2 | ||
|
||
|
||
nethack_env = create_nethack_env() | ||
cartpole_env = create_cartpole_env() | ||
|
||
|
@@ -29,7 +29,10 @@ | |
"get_value": get_test_values, | ||
"device": "cpu", | ||
"num_processes": N_ENVS, | ||
"num_steps": 2048 | ||
"num_steps": 2048, | ||
"robust_plr": True, | ||
"eval_envs": create_nethack_env(), | ||
"action_value_fn": get_action_value | ||
Review comment: If you need this to output a single action and value for now, just do something simple for testing.

    }),
    (SimpleBoxCurriculum, create_cartpole_env, (cartpole_env.task_space,), {}),
    (AnnealingBoxCurriculum, create_cartpole_env, (cartpole_env.task_space,), {
Review comment: Please reduce the code duplication here.
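Along the lines of the earlier suggestion to "just do something simple for testing", a hypothetical stand-in for get_action_value (not the implementation in syllabus.tests) that returns a single int action and a scalar value, as the review on evaluate_task also asks for:

    import random

    def get_action_value(obs, num_actions=8):
        # num_actions is illustrative; match it to the evaluation env's action space.
        action = random.randrange(num_actions)
        value = 0.0  # a constant value estimate is enough for a smoke test
        return action, value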