Commit

add syllabus wandb logging
kywch committed Apr 18, 2024
1 parent acf7596 commit 89859d5
Showing 3 changed files with 8 additions and 7 deletions.
syllabus_wrapper.py (4 changes: 2 additions & 2 deletions)
@@ -17,7 +17,7 @@
 from reinforcement_learning import environment


-def make_env_creator(args, agent_module):
+def make_syllabus_env_creator(args, agent_module):
     sample_env_creator = environment.make_env_creator(
         reward_wrapper_cls=agent_module.RewardWrapper, syllabus_wrapper=True
     )
@@ -28,7 +28,7 @@ def make_env_creator(args, agent_module):
     curriculum = MultiagentSharedCurriculumWrapper(curriculum, sample_env.possible_agents)
     curriculum = make_multiprocessing_curriculum(curriculum)

-    return environment.make_env_creator(
+    return curriculum, environment.make_env_creator(
         reward_wrapper_cls=agent_module.RewardWrapper, syllabus=curriculum
     )

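The wrapper now hands the multiprocessing curriculum back alongside the env_creator so the caller keeps a handle on it for metric logging. A minimal consumption sketch, mirroring the train.py change below (no names beyond those in the diff are assumed):

curriculum, env_creator = make_syllabus_env_creator(args, agent_module)
# env_creator builds the Syllabus-wrapped envs as before; curriculum is kept
# so the training loop can later call curriculum.log_metrics(wandb_run, step=...)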
train.py (5 changes: 3 additions & 2 deletions)
@@ -215,8 +215,9 @@ def update_args(args, mode=None):
 args = update_args(args, mode=args["mode"])

 # Make default or syllabus-based env_creator
+syllabus = None
 if args.syllabus is True:
-    env_creator = syllabus_wrapper.make_env_creator(args, agent_module)
+    syllabus, env_creator = syllabus_wrapper.make_syllabus_env_creator(args, agent_module)
 else:
     args.env.curriculum_file_path = args.curriculum
     env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper)
@@ -232,7 +233,7 @@ def update_args(args, mode=None):
 args.exp_name = f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}"

 if args.mode == "train":
-    train(args, env_creator, agent_creator)
+    train(args, env_creator, agent_creator, syllabus)
     exit(0)
 elif args.mode == "sweep":
     sweep(args, env_creator, agent_creator)
train_helper.py (6 changes: 3 additions & 3 deletions)
@@ -4,12 +4,10 @@

 import dill
 import wandb
-
 import torch
 import numpy as np

 import pufferlib.policy_pool as pp
-
 from nmmo.render.replay_helper import FileReplayHelper
 from nmmo.task.task_spec import make_task_from_spec

@@ -48,7 +46,7 @@ def init_wandb(args, resume=True):
     return wandb.init(**wandb_kwargs)


-def train(args, env_creator, agent_creator):
+def train(args, env_creator, agent_creator, syllabus=None):
     data = clean_pufferl.create(
         config=args.train,
         agent_creator=agent_creator,
@@ -63,6 +61,8 @@ def train(args, env_creator, agent_creator):
     while not clean_pufferl.done_training(data):
         clean_pufferl.evaluate(data)
         clean_pufferl.train(data)
+        if syllabus is not None:
+            syllabus.log_metrics(data.wandb, step=data.global_step)

     print("Done training. Saving data...")
     clean_pufferl.close(data)
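For context, the loop above only assumes the curriculum exposes log_metrics(wandb_run, step=...), matching the call syllabus.log_metrics(data.wandb, step=data.global_step). Below is a minimal sketch of that shape using a hypothetical stand-in class; the metric names and bookkeeping are illustrative, not Syllabus's actual implementation:

import wandb

class DummyCurriculum:
    """Hypothetical stand-in for the curriculum returned by make_syllabus_env_creator."""

    def __init__(self):
        self.task_counts = {"task_a": 0, "task_b": 0}  # illustrative bookkeeping only

    def log_metrics(self, wandb_run, step=None):
        # Same call shape as the training loop above:
        #   syllabus.log_metrics(data.wandb, step=data.global_step)
        wandb_run.log(
            {f"curriculum/{name}_count": count for name, count in self.task_counts.items()},
            step=step,
        )

# Usage sketch (requires an active wandb run):
# run = wandb.init(project="nmmo")
# DummyCurriculum().log_metrics(run, step=0)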
