Skip to content

Commit

Permalink
Update for new Syllabus API
Browse files (browse the repository at this point in the history)
  • Loading branch information
RyanNavillus committed Nov 26, 2024
1 parent d24f6e7 commit 35ca825
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 18 deletions.
15 changes: 6 additions & 9 deletions reinforcement_learning/clean_pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,12 @@ def evaluate(data):
env_ids = [info["env_id"] for info in i["learner"]]

update = {
"update_type": "on_demand",
"metrics": {
"value": data.prev_value,
"next_value": value,
"rew": r,
"dones": d,
"tasks": tasks,
"env_ids": env_ids
},
"value": data.prev_value,
"next_value": value,
"rew": r,
"dones": d,
"tasks": tasks,
"env_ids": env_ids
}
data.curriculum.update(update)
data.prev_value = value
Expand Down
7 changes: 3 additions & 4 deletions reinforcement_learning/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pufferlib
import pufferlib.emulation
from pettingzoo.utils.wrappers.base_parallel import BaseParallelWrapper
from syllabus.core import PettingZooMultiProcessingSyncWrapper as SyllabusSyncWrapper
from syllabus.core import PettingZooSyncWrapper as SyllabusSyncWrapper

from syllabus_wrapper import SyllabusSeedWrapper

Expand Down Expand Up @@ -63,9 +63,8 @@ def env_creator(*args, **kwargs):
if syllabus is not None:
env = SyllabusSyncWrapper(
env,
syllabus.get_components(),
update_on_step=syllabus.requires_step_updates,
task_space=env.task_space,
env.task_space,
syllabus.components,
batch_size=8,
)

Expand Down
8 changes: 4 additions & 4 deletions syllabus_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
from collections import defaultdict
from reinforcement_learning import environment
from syllabus.task_space import TaskSpace
from syllabus.task_space import DiscreteTaskSpace
from syllabus.core.evaluator import CleanRLDiscreteEvaluator, Evaluator
from syllabus.core.task_interface import PettingZooTaskWrapper
from syllabus.curricula import SequentialCurriculum, PrioritizedLevelReplay, CentralizedPrioritizedLevelReplay
Expand Down Expand Up @@ -191,13 +191,13 @@ class SyllabusSeedWrapper(PettingZooTaskWrapper):
Wrapper to handle tasks for the Neural MMO environment.
"""

task_space = TaskSpace(200)
task_space = DiscreteTaskSpace(200)

def __init__(self, env: gym.Env):
super().__init__(env)
self.env = env

self.task_space = SyllabusTaskWrapper.task_space
self.task_space = SyllabusSeedWrapper.task_space
self.change_task(self.task_space.sample())
self._task_index = None
self.task_fn = None
Expand Down Expand Up @@ -227,7 +227,7 @@ class SyllabusTaskWrapper(PettingZooTaskWrapper):
"""

# task_space = TaskSpace((18, 200), [tuple(np.arange(18)), tuple(np.arange(200))])
task_space = TaskSpace(200)
task_space = DiscreteTaskSpace(200)

# task_space = TaskSpace((2719, 200), [tuple(np.arange(2719)), tuple(np.arange(200))])

Expand Down
2 changes: 1 addition & 1 deletion train_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def train(args, env_creator, agent_creator, syllabus=None):
env_outputs = evaluate_agent(args, eval_data, env_outputs, data.wandb, data.global_step)

if syllabus is not None:
syllabus.log_metrics(data.wandb, step=data.global_step)
syllabus.log_metrics(data.wandb, [], step=data.global_step)

clean_pufferl.train(data)

Expand Down

0 comments on commit 35ca825

Please sign in to comment.