diff --git a/syllabus_wrapper.py b/syllabus_wrapper.py index 682c2a6..64f57d4 100644 --- a/syllabus_wrapper.py +++ b/syllabus_wrapper.py @@ -17,7 +17,7 @@ from reinforcement_learning import environment -def make_env_creator(args, agent_module): +def make_syllabus_env_creator(args, agent_module): sample_env_creator = environment.make_env_creator( reward_wrapper_cls=agent_module.RewardWrapper, syllabus_wrapper=True ) @@ -28,7 +28,7 @@ def make_env_creator(args, agent_module): curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents) curriculum = make_multiprocessing_curriculum(curriculum) - return environment.make_env_creator( + return curriculum, environment.make_env_creator( reward_wrapper_cls=agent_module.RewardWrapper, syllabus=curriculum ) diff --git a/train.py b/train.py index f2b5aca..4cbc46b 100644 --- a/train.py +++ b/train.py @@ -215,8 +215,9 @@ def update_args(args, mode=None): args = update_args(args, mode=args["mode"]) # Make default or syllabus-based env_creator + syllabus = None if args.syllabus is True: - env_creator = syllabus_wrapper.make_env_creator(args, agent_module) + syllabus, env_creator = syllabus_wrapper.make_syllabus_env_creator(args, agent_module) else: args.env.curriculum_file_path = args.curriculum env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper) @@ -232,7 +233,7 @@ def update_args(args, mode=None): args.exp_name = f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}" if args.mode == "train": - train(args, env_creator, agent_creator) + train(args, env_creator, agent_creator, syllabus) exit(0) elif args.mode == "sweep": sweep(args, env_creator, agent_creator) diff --git a/train_helper.py b/train_helper.py index 90d9d12..8a36cb3 100644 --- a/train_helper.py +++ b/train_helper.py @@ -4,12 +4,10 @@ import dill import wandb - import torch import numpy as np import pufferlib.policy_pool as pp - from nmmo.render.replay_helper import FileReplayHelper from nmmo.task.task_spec import make_task_from_spec @@ -48,7 +46,7 @@ def init_wandb(args, resume=True): return wandb.init(**wandb_kwargs) -def train(args, env_creator, agent_creator): +def train(args, env_creator, agent_creator, syllabus=None): data = clean_pufferl.create( config=args.train, agent_creator=agent_creator, @@ -63,6 +61,8 @@ def train(args, env_creator, agent_creator): while not clean_pufferl.done_training(data): clean_pufferl.evaluate(data) clean_pufferl.train(data) + if syllabus is not None: + syllabus.log_metrics(data.wandb, step=data.global_step) print("Done training. Saving data...") clean_pufferl.close(data)