diff --git a/syllabus/curricula/plr/central_plr_wrapper.py b/syllabus/curricula/plr/central_plr_wrapper.py
index 7f69ea85..d9b389bf 100644
--- a/syllabus/curricula/plr/central_plr_wrapper.py
+++ b/syllabus/curricula/plr/central_plr_wrapper.py
@@ -133,7 +133,7 @@ def __init__(
         self._gae_lambda = gae_lambda
         self._supress_usage_warnings = suppress_usage_warnings
         self._task2index = {task: i for i, task in enumerate(self.tasks)}
-        self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict)
+        self._task_sampler = TaskSampler(self.tasks, task_space=task_space, action_space=action_space, **task_sampler_kwargs_dict)
         self._rollouts = RolloutStorage(
             self._num_steps,
             self._num_processes,
diff --git a/syllabus/curricula/plr/plr_wrapper.py b/syllabus/curricula/plr/plr_wrapper.py
index 9c808ddc..034aa3dd 100644
--- a/syllabus/curricula/plr/plr_wrapper.py
+++ b/syllabus/curricula/plr/plr_wrapper.py
@@ -149,6 +149,9 @@ class PrioritizedLevelReplay(Curriculum):
         gamma (float): The discount factor used to compute returns
         gae_lambda (float): The GAE lambda value.
         suppress_usage_warnings (bool): Whether to suppress warnings about improper usage.
+        robust_plr (bool): Option to use RobustPLR.
+        eval_envs: Evaluation environments for RobustPLR.
+        action_value_fn (callable): A function that takes an observation as input and returns an action and value.
         **curriculum_kwargs: Keyword arguments to pass to the curriculum.
     """
     REQUIRES_STEP_UPDATES = True
@@ -170,6 +173,9 @@ def __init__(
         suppress_usage_warnings=False,
         get_value=null,
         get_action_log_dist=null,
+        robust_plr: bool = False,  # Option to use RobustPLR
+        eval_envs=None,
+        action_value_fn=None,
         **curriculum_kwargs,
     ):
         # Preprocess curriculum intialization args
@@ -186,6 +192,9 @@ def __init__(
             task_sampler_kwargs_dict["num_actors"] = num_processes
         super().__init__(task_space, *curriculum_args, **curriculum_kwargs)
 
+        if robust_plr and eval_envs is None:
+            raise UsageError("RobustPLR requires evaluation environments to be provided.")
+
         self._num_steps = num_steps  # Number of steps stored in rollouts and used to update task sampler
         self._num_processes = num_processes  # Number of parallel environments
         self._gamma = gamma
@@ -193,8 +202,12 @@ def __init__(
         self._supress_usage_warnings = suppress_usage_warnings
         self._get_action_log_dist = get_action_log_dist
         self._task2index = {task: i for i, task in enumerate(self.tasks)}
+        self._robust_plr = robust_plr
+        self._eval_envs = eval_envs
+        self.action_value_fn = action_value_fn
+
+        self._task_sampler = TaskSampler(self.tasks, task_space=task_space, action_space=action_space, robust_plr=robust_plr, eval_envs=eval_envs, action_value_fn=action_value_fn, **task_sampler_kwargs_dict)
-        self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict)
         self._rollouts = RolloutStorage(
             self._num_steps,
             self._num_processes,
diff --git a/syllabus/curricula/plr/storage.py b/syllabus/curricula/plr/storage.py
new file mode 100644
index 00000000..70ab76ae
--- /dev/null
+++ b/syllabus/curricula/plr/storage.py
@@ -0,0 +1,69 @@
+import gymnasium as gym
+import torch
+
+
+class RolloutStorage(object):
+    def __init__(
+        self,
+        num_steps: int,
+        num_processes: int,
+        requires_value_buffers: bool,
+        action_space: gym.Space = None,
+    ):
+        self._requires_value_buffers = requires_value_buffers
+        self.tasks = torch.zeros(num_steps, num_processes, 1, dtype=torch.int)
+        self.masks = torch.ones(num_steps + 1, num_processes, 1)
+
+        if requires_value_buffers:
+            self.returns = torch.zeros(num_steps + 1, num_processes, 1)
+            self.rewards = torch.zeros(num_steps, num_processes, 1)
+            self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
+        else:
+            if action_space is None:
+                raise ValueError(
+                    "Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'"
+                )
+            self.action_log_dist = torch.zeros(num_steps, num_processes, action_space.n)
+
+        self.num_steps = num_steps
+        self.step = 0
+
+    def to(self, device):
+        self.masks = self.masks.to(device)
+        self.tasks = self.tasks.to(device)
+        if self._requires_value_buffers:
+            self.rewards = self.rewards.to(device)
+            self.value_preds = self.value_preds.to(device)
+            self.returns = self.returns.to(device)
+        else:
+            self.action_log_dist = self.action_log_dist.to(device)
+
+    def insert(self, masks, action_log_dist=None, value_preds=None, rewards=None, tasks=None):
+        if self._requires_value_buffers:
+            assert (value_preds is not None and rewards is not None), "Selected strategy requires value_preds and rewards"
+            if len(rewards.shape) == 3:
+                rewards = rewards.squeeze(2)
+            self.value_preds[self.step].copy_(torch.as_tensor(value_preds))
+            self.rewards[self.step].copy_(torch.as_tensor(rewards))
+            self.masks[self.step + 1].copy_(torch.as_tensor(masks))
+        else:
+            self.action_log_dist[self.step].copy_(action_log_dist)
+        if tasks is not None:
+            # assert isinstance(tasks[0], (int, torch.int32)), "Provided task must be an integer"
+            self.tasks[self.step].copy_(torch.as_tensor(tasks))
+        self.step = (self.step + 1) % self.num_steps
+
+    def after_update(self):
+        self.masks[0].copy_(self.masks[-1])
+
+    def compute_returns(self, next_value, gamma, gae_lambda):
+        assert self._requires_value_buffers, "Selected strategy does not use compute_rewards."
+        self.value_preds[-1] = next_value
+        gae = 0
+        for step in reversed(range(self.rewards.size(0))):
+            delta = (
+                self.rewards[step]
+                + gamma * self.value_preds[step + 1] * self.masks[step + 1]
+                - self.value_preds[step]
+            )
+            gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae
+            self.returns[step] = gae + self.value_preds[step]
\ No newline at end of file
diff --git a/syllabus/curricula/plr/task_sampler.py b/syllabus/curricula/plr/task_sampler.py
index c1e97a18..0907bd2a 100644
--- a/syllabus/curricula/plr/task_sampler.py
+++ b/syllabus/curricula/plr/task_sampler.py
@@ -1,8 +1,17 @@
 # Code heavily based on the original Prioritized Level Replay implementation from https://github.com/facebookresearch/level-replay
 # If you use this code, please cite the above codebase and original PLR paper: https://arxiv.org/abs/2010.03934
+
 import gymnasium as gym
 import numpy as np
 import torch
+from typing import List
+
+from syllabus.curricula.plr.storage import RolloutStorage
+from syllabus.task_space.task_space import TaskSpace
+
+
+def null(x):
+    return None
 
 
 class TaskSampler:
@@ -23,10 +32,13 @@ class TaskSampler:
         staleness_coef (float): Linear interpolation weight for task staleness vs. task score. 0.0 means only use task score, 1.0 means only use staleness.
         staleness_transform (str): Transform to apply to task staleness. One of "constant", "max", "eps_greedy", "rank", "power", "softmax".
         staleness_temperature (float): Temperature for staleness transform. Increasing temperature makes the sampling distribution more uniform.
+        eval_envs (List[gym.Env]): List of evaluation environments
+        action_value_fn (callable): A function that takes an observation as input and returns an action and value.
""" def __init__( self, tasks: list, + task_space: TaskSpace, action_space: gym.spaces.Space = None, num_actors: int = 1, strategy: str = "value_l1", @@ -37,10 +49,22 @@ def __init__( rho: float = 1.0, nu: float = 0.5, alpha: float = 1.0, + num_steps: int = 256, + num_processes: int = 1, + gamma: float = 0.999, + gae_lambda: float = 0.95, staleness_coef: float = 0.1, staleness_transform: str = "power", staleness_temperature: float = 1.0, + + robust_plr: bool = False, + eval_envs = None, + action_value_fn=None, + get_value=None, + observation_space = None, + ): + self.task_space = task_space self.action_space = action_space self.tasks = tasks self.num_tasks = len(self.tasks) @@ -53,9 +77,18 @@ def __init__( self.rho = rho self.nu = nu self.alpha = float(alpha) + self.gamma = gamma + self.gae_lambda = gae_lambda self.staleness_coef = staleness_coef self.staleness_transform = staleness_transform self.staleness_temperature = staleness_temperature + self.robust_plr = robust_plr + self.eval_envs = eval_envs + self.action_value_fn = action_value_fn + self.num_steps = num_steps + self.num_processes = num_processes + self._get_values = get_value + self.observation_space = observation_space self.unseen_task_weights = np.array([1.0] * self.num_tasks) self.task_scores = np.array([0.0] * self.num_tasks, dtype=float) @@ -63,6 +96,13 @@ def __init__( self.partial_task_steps = np.zeros((num_actors, self.num_tasks), dtype=np.int64) self.task_staleness = np.array([0.0] * self.num_tasks, dtype=float) + self._robust_rollouts = RolloutStorage( + self.num_steps, + self.num_processes, + self.requires_value_buffers, + action_space=action_space, + ) + self.next_task_index = 0 # Only used for sequential strategy # Logging metrics @@ -73,28 +113,34 @@ def __init__( 'Must provide action space to PLR if using "policy_entropy", "least_confidence", or "min_margin" strategies' ) - def update_with_rollouts(self, rollouts, actor_id=None): - if self.strategy == "random": - return - - # Update with a RolloutStorage object + def _get_score_function(self): if self.strategy == "policy_entropy": - score_function = self._average_entropy + return self._average_entropy elif self.strategy == "least_confidence": - score_function = self._average_least_confidence + return self._average_least_confidence elif self.strategy == "min_margin": - score_function = self._average_min_margin + return self._average_min_margin elif self.strategy == "gae": - score_function = self._average_gae + return self._average_gae elif self.strategy == "value_l1": - score_function = self._average_value_l1 + return self._average_value_l1 elif self.strategy == "one_step_td_error": - score_function = self._one_step_td_error + return self._one_step_td_error else: raise ValueError(f"Unsupported strategy, {self.strategy}") + def update_with_rollouts(self, rollouts, actor_id=None): + if self.strategy == "random": + return + score_function = self._get_score_function() self._update_with_rollouts(rollouts, score_function, actor_index=actor_id) + def update_with_episode_data(self, episode_data): + if self.strategy == "random": + return + score_function = self._get_score_function() + self._update_with_episode_data(episode_data, score_function) + def update_task_score(self, actor_index, task_idx, score, num_steps): score = self._partial_update_task_score(actor_index, task_idx, score, num_steps, done=True) @@ -262,12 +308,54 @@ def _sample_unseen_level(self): return task + def _evaluate_unseen_level(self): + sample_weights = self.unseen_task_weights / 
self.unseen_task_weights.sum() + task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0] + tasks = np.array(self.tasks) + task = tasks[task_idx] + + episode_data = self.evaluate_task(task, self.eval_envs, self.action_value_fn) + self.update_with_episode_data(episode_data) + + self._update_staleness(task_idx) + + def evaluate_task(self, task, env, action_value_fn): + if env is None: + raise ValueError("Environment object is None. Please ensure it is properly initialized.") + + obs, info = env.reset(new_task=task) + done = False + + while not done: + action, value = action_value_fn(obs) + + obs, rew, term, trunc, _ = env.step(action) + + task_encoded = self.task_space.encode(task) + + mask = torch.FloatTensor([0.0] if term or trunc else [1.0]) + self._robust_rollouts.insert(mask, value_preds=value, rewards=torch.Tensor([rew]), tasks=torch.Tensor([task_encoded])) + + # Check if the episode is done + if term or trunc: + done = True + + _, next_value = action_value_fn(obs) + self._robust_rollouts.compute_returns(next_value, self.gamma, self.gae_lambda) + return { + "tasks": self._robust_rollouts.tasks, + "masks": self._robust_rollouts.masks, + "rewards": self._robust_rollouts.rewards, + "value_preds": self._robust_rollouts.value_preds, + "returns": self._robust_rollouts.returns, + } + def sample(self, strategy=None): if not strategy: strategy = self.strategy if strategy == "random": - task_idx = np.random.choice(range((self.num_tasks))) + task_idx = np.random.choice(range(self.num_tasks)) task = self.tasks[task_idx] return task @@ -287,15 +375,90 @@ def sample(self, strategy=None): return self._sample_replay_level() # Otherwise, sample a new level - return self._sample_unseen_level() + if self.robust_plr: + self._evaluate_unseen_level() + self._robust_rollouts.after_update() + self.after_update() + return self.sample(strategy=strategy) + else: + return self._sample_unseen_level() elif self.replay_schedule == "proportionate": if proportion_seen >= self.rho and np.random.rand() < proportion_seen: return self._sample_replay_level() else: - return self._sample_unseen_level() + if self.robust_plr: + self._evaluate_unseen_level() + self._robust_rollouts.after_update() + self.after_update() + return self.sample(strategy=strategy) + else: + return self._sample_unseen_level() else: - raise NotImplementedError(f"Unsupported replay schedule: {self.replay_schedule}. Must be 'fixed' or 'proportionate'.") + raise NotImplementedError( + f"Unsupported replay schedule: {self.replay_schedule}. 
Must be 'fixed' or 'proportionate'.") + + def _update_with_episode_data(self, episode_data, score_function): + tasks = episode_data["tasks"] + if not self.requires_value_buffers: + policy_logits = episode_data.action_log_dist + done = ~(episode_data["masks"] > 0) + + total_steps, num_actors = tasks.shape[:2] + + for actor_index in range(num_actors): + done_steps = done[:, actor_index].nonzero()[:total_steps, 0] + start_t = 0 + + for t in done_steps: + if not start_t < total_steps: + break + + if (t == 0): # if t is 0, then this done step caused a full update of previous last cycle + continue + + # If there is only 1 step, we can't calculate the one-step td error + if self.strategy == "one_step_td_error" and t - start_t <= 1: + continue + + task_idx_t = tasks[start_t, actor_index].item() + + # Store kwargs for score function + score_function_kwargs = {} + if self.requires_value_buffers: + score_function_kwargs["returns"] = episode_data["returns"][start_t:t, actor_index] + score_function_kwargs["rewards"] = episode_data["rewards"][start_t:t, actor_index] + score_function_kwargs["value_preds"] = episode_data["value_preds"][start_t:t, actor_index] + else: + episode_logits = policy_logits[start_t:t, actor_index] + score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1) + score = score_function(**score_function_kwargs) + num_steps = len(episode_data["tasks"][start_t:t, actor_index]) + # TODO: Check that task_idx_t is correct + self.update_task_score(actor_index, task_idx_t, score, num_steps) + + start_t = t.item() + if start_t < total_steps: + # If there is only 1 step, we can't calculate the one-step td error + if self.strategy == "one_step_td_error" and start_t == total_steps - 1: + continue + # TODO: Check this too + task_idx_t = tasks[start_t, actor_index].item() + + # Store kwargs for score function + score_function_kwargs = {} + if self.requires_value_buffers: + score_function_kwargs["returns"] = episode_data["returns"][start_t:, actor_index] + score_function_kwargs["rewards"] = episode_data["rewards"][start_t:, actor_index] + score_function_kwargs["value_preds"] = episode_data["value_preds"][start_t:, actor_index] + else: + episode_logits = policy_logits[start_t:, actor_index] + score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1) + + score = score_function(**score_function_kwargs) + self._last_score = score + num_steps = len(episode_data["tasks"][start_t:, actor_index]) + self._partial_update_task_score(actor_index, task_idx_t, score, num_steps) def sample_weights(self): weights = self._score_transform(self.score_transform, self.temperature, self.task_scores) diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py index 5aebd671..49464bb8 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py @@ -185,7 +185,7 @@ def level_replay_evaluate( return mean_returns, stddev_returns, normalized_mean_returns -def make_value_fn(): +def make_value_fn(agent): def get_value(obs): obs = np.array(obs) with torch.no_grad(): @@ -193,6 +193,15 @@ def get_value(obs): return get_value +def make_action_value_fn(agent): + def get_action_value(obs): + obs = np.array(obs[None,:]) + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(torch.Tensor(obs).to(device)) + return action.cpu().numpy(), value + return get_action_value + + if __name__ == "__main__": args = parse_args() run_name 
= f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" @@ -224,7 +233,15 @@ def get_value(obs): torch.backends.cudnn.deterministic = args.torch_deterministic device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") - print("Device:", device) + + print("Creating agent") + agent = ProcgenAgent( + (64, 64, 3), + 15, + arch="large", + base_kwargs={'recurrent': False, 'hidden_size': 256} + ).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) # Curriculum setup curriculum = None @@ -236,6 +253,9 @@ def get_value(obs): # Intialize Curriculum Method if args.curriculum_method == "plr": print("Using prioritized level replay.") + + plr_eval_env = make_env(args.env_id, args.seed, num_levels=200)() + plr_eval_env = ProcgenTaskWrapper(plr_eval_env, args.env_id, seed=args.seed) curriculum = PrioritizedLevelReplay( sample_env.task_space, sample_env.observation_space, @@ -244,7 +264,10 @@ def get_value(obs): gamma=args.gamma, gae_lambda=args.gae_lambda, task_sampler_kwargs_dict={"strategy": "value_l1"}, - get_value=make_value_fn(), + get_value=make_value_fn(agent), + robust_plr=True, + eval_envs=plr_eval_env, + action_value_fn=make_action_value_fn(agent), ) elif args.curriculum_method == "dr": print("Using domain randomization.") @@ -282,16 +305,6 @@ def get_value(obs): ) envs = wrap_vecenv(envs) - assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" - print("Creating agent") - agent = ProcgenAgent( - envs.single_observation_space.shape, - envs.single_action_space.n, - arch="large", - base_kwargs={'recurrent': False, 'hidden_size': 256} - ).to(device) - optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) - # ALGO Logic: Storage setup obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) diff --git a/syllabus/tests/utils.py b/syllabus/tests/utils.py index 98bac823..fe49f773 100644 --- a/syllabus/tests/utils.py +++ b/syllabus/tests/utils.py @@ -186,10 +186,17 @@ def run_ray_multiprocess(env_fn, env_args=(), env_kwargs={}, curriculum=None, nu ray.kill(curriculum.curriculum) return ray_speed + def get_test_values(x): return torch.unsqueeze(torch.Tensor(np.array([0] * len(x))), -1) +def get_action_value(obs): + action = 0 + value = 0 + return action, value + + # Sync Test Environment def create_synctest_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): env = SyncTestEnv(*env_args, **env_kwargs) @@ -223,6 +230,7 @@ def create_nethack_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): warnings.warn("Unable to import nle.") env = NetHackScore(*env_args, **env_kwargs) + env = GymV21CompatibilityV0(env=env) env = NethackTaskWrapper(env) if type == "queue": diff --git a/tests/multiprocessing_smoke_tests.py b/tests/multiprocessing_smoke_tests.py index b7881795..9b26fb8d 100644 --- a/tests/multiprocessing_smoke_tests.py +++ b/tests/multiprocessing_smoke_tests.py @@ -1,5 +1,6 @@ """ Test curriculum synchronization across multiple processes. 
""" import pytest +import gym from nle.env.tasks import NetHackScore, NetHackScout, NetHackStaircase from syllabus.core import make_multiprocessing_curriculum, make_ray_curriculum @@ -9,14 +10,13 @@ LearningProgressCurriculum, NoopCurriculum, PrioritizedLevelReplay, SequentialCurriculum, SimpleBoxCurriculum) -from syllabus.tests import (create_cartpole_env, create_nethack_env, +from syllabus.tests import (create_cartpole_env, create_nethack_env, get_action_value, get_test_values, run_native_multiprocess, run_ray_multiprocess, run_single_process) N_ENVS = 2 N_EPISODES = 2 - nethack_env = create_nethack_env() cartpole_env = create_cartpole_env() @@ -29,7 +29,10 @@ "get_value": get_test_values, "device": "cpu", "num_processes": N_ENVS, - "num_steps": 2048 + "num_steps": 2048, + "robust_plr": True, + "eval_envs": create_nethack_env(), + "action_value_fn": get_action_value }), (SimpleBoxCurriculum, create_cartpole_env, (cartpole_env.task_space,), {}), (AnnealingBoxCurriculum, create_cartpole_env, (cartpole_env.task_space,), {