From 35ca82556d3a2003c62bfe874edf802cc8e65108 Mon Sep 17 00:00:00 2001
From: Ryan Sullivan
Date: Tue, 26 Nov 2024 15:40:46 -0500
Subject: [PATCH] Update for new Syllabus API

---
 reinforcement_learning/clean_pufferl.py | 15 ++++++---------
 reinforcement_learning/environment.py   |  7 +++----
 syllabus_wrapper.py                     |  8 ++++----
 train_helper.py                         |  2 +-
 4 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/reinforcement_learning/clean_pufferl.py b/reinforcement_learning/clean_pufferl.py
index 3ec5e65..909fd17 100644
--- a/reinforcement_learning/clean_pufferl.py
+++ b/reinforcement_learning/clean_pufferl.py
@@ -335,15 +335,12 @@ def evaluate(data):
             env_ids = [info["env_id"] for info in i["learner"]]
 
             update = {
-                "update_type": "on_demand",
-                "metrics": {
-                    "value": data.prev_value,
-                    "next_value": value,
-                    "rew": r,
-                    "dones": d,
-                    "tasks": tasks,
-                    "env_ids": env_ids
-                },
+                "value": data.prev_value,
+                "next_value": value,
+                "rew": r,
+                "dones": d,
+                "tasks": tasks,
+                "env_ids": env_ids
             }
             data.curriculum.update(update)
             data.prev_value = value
diff --git a/reinforcement_learning/environment.py b/reinforcement_learning/environment.py
index 852fd02..d6450b2 100644
--- a/reinforcement_learning/environment.py
+++ b/reinforcement_learning/environment.py
@@ -4,7 +4,7 @@
 import pufferlib
 import pufferlib.emulation
 from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
-from syllabus.core import PettingZooMultiProcessingSyncWrapper as SyllabusSyncWrapper
+from syllabus.core import PettingZooSyncWrapper as SyllabusSyncWrapper
 
 from syllabus_wrapper import SyllabusSeedWrapper
 
@@ -63,9 +63,8 @@ def env_creator(*args, **kwargs):
     if syllabus is not None:
         env = SyllabusSyncWrapper(
             env,
-            syllabus.get_components(),
-            update_on_step=syllabus.requires_step_updates,
-            task_space=env.task_space,
+            env.task_space,
+            syllabus.components,
             batch_size=8,
         )
 
diff --git a/syllabus_wrapper.py b/syllabus_wrapper.py
index fad72eb..435537d 100644
--- a/syllabus_wrapper.py
+++ b/syllabus_wrapper.py
@@ -6,7 +6,7 @@
 import copy
 from collections import defaultdict
 from reinforcement_learning import environment
-from syllabus.task_space import TaskSpace
+from syllabus.task_space import DiscreteTaskSpace
 from syllabus.core.evaluator import CleanRLDiscreteEvaluator, Evaluator
 from syllabus.core.task_interface import PettingZooTaskWrapper
 from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay, CentralizedPrioritizedLevelReplay
@@ -191,13 +191,13 @@ class SyllabusSeedWrapper(PettingZooTaskWrapper):
     """
     Wrapper to handle tasks for the Neural MMO environment.
""" - task_space = TaskSpace(200) + task_space = DiscreteTaskSpace(200) def __init__(self, env: gym.Env): super().__init__(env) self.env = env - self.task_space = SyllabusTaskWrapper.task_space + self.task_space = SyllabusSeedWrapper.task_space self.change_task(self.task_space.sample()) self._task_index = None self.task_fn = None @@ -227,7 +227,7 @@ class SyllabusTaskWrapper(PettingZooTaskWrapper): """ # task_space = TaskSpace((18, 200), [tuple(np.arange(18)), tuple(np.arange(200))]) - task_space = TaskSpace(200) + task_space = DiscreteTaskSpace(200) # task_space = TaskSpace((2719, 200), [tuple(np.arange(2719)), tuple(np.arange(200))]) diff --git a/train_helper.py b/train_helper.py index e9d44c4..bde697e 100644 --- a/train_helper.py +++ b/train_helper.py @@ -78,7 +78,7 @@ def train(args, env_creator, agent_creator, syllabus=None): env_outputs = evaluate_agent(args, eval_data, env_outputs, data.wandb, data.global_step) if syllabus is not None: - syllabus.log_metrics(data.wandb, step=data.global_step) + syllabus.log_metrics(data.wandb, [], step=data.global_step) clean_pufferl.train(data)