Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifying make_env function to enable task information to propagate #41

Open
wants to merge 5 commits into
base: 2.0
Choose a base branch
from
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pufferlib.emulation
import pufferlib.frameworks.cleanrl
import pufferlib.registry.nmmo
from nmmo.task.task_api import make_team_tasks
import torch

import clean_pufferl
Expand All @@ -30,7 +31,7 @@
help="path to model to load (default: None)")
parser.add_argument(
"--model.type",
dest="model_type", type=str, default="realikun",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid changing default args for testing when possible. A quick way to check for this is to run a diff on the train script.

dest="model_type", type=str, default="basic",
help="model type (default: realikun)")

parser.add_argument(
Expand Down Expand Up @@ -60,7 +61,7 @@
action="store_true", default=False,
help="reset on death (default: False)")
parser.add_argument(
"--env.num_maps", dest="num_maps", type=int, default=128,
nikhilpinnaparaju marked this conversation as resolved.
Show resolved Hide resolved
"--env.num_maps", dest="num_maps", type=int, default=1,
help="number of maps to use for training (default: 1)")
parser.add_argument(
"--env.maps_path", dest="maps_path", type=str, default="maps/train/",
Expand Down Expand Up @@ -99,7 +100,7 @@
"--rollout.num_cores", dest="num_cores", type=int, default=None,
help="number of cores to use for training (default: num_envs)")
parser.add_argument(
"--rollout.num_envs", dest="num_envs", type=int, default=4,
nikhilpinnaparaju marked this conversation as resolved.
Show resolved Hide resolved
"--rollout.num_envs", dest="num_envs", type=int, default=1,
help="number of environments to use for training (default: 1)")
parser.add_argument(
"--rollout.num_buffers", dest="num_buffers", type=int, default=4,
Expand Down Expand Up @@ -140,7 +141,7 @@
help="wandb entity name (default: None)")

parser.add_argument(
"--ppo.bptt_horizon", dest="bptt_horizon", type=int, default=8,
"--ppo.bptt_horizon", dest="bptt_horizon", type=int, default=16,
help="train on bptt_horizon steps of a rollout at a time. "
"use this to reduce GPU memory (default: 16)")

Expand Down Expand Up @@ -198,7 +199,37 @@
)

def make_env():
return nmmo.Env(config)
import pickle as pkl
import numpy as np
import random

import os
print('cwd', os.getcwd())
with open('./pickled_task_with_embs.pkl', 'rb') as f:
nikhilpinnaparaju marked this conversation as resolved.
Show resolved Hide resolved
task_spec = pkl.load(f)

# tasks = [d[1] for d in task_spec]
num_tasks = len(task_spec)
teams = team_helper.teams
single_task = task_spec[0]

# make_task_fn = lambda: tasks
# task_spec_sampled =np.random.choice(task_spec, size=len(teams), replace=False)
task_spec_sampled = random.sample(task_spec, len(teams))
tasks = make_team_tasks(teams, task_spec_sampled)
make_task_fn = lambda: tasks

# env = nmmo.Env(config)
class MyNMMO(nmmo.Env):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used MyNMMO only as a placeholder to indicate that you should pick your own class name -- call it TestTaskNMMO or something similar.

def __init__(self, config):
    """Construct the environment by handing ``config`` to the parent class.

    Args:
        config: Forwarded unchanged to the parent constructor.
    """
    super().__init__(config)

def reset(self, *args, **kwargs):
    """Reset through the parent class, always supplying ``make_task_fn``.

    Positional and keyword arguments are forwarded untouched; the
    ``make_task_fn`` callable from the enclosing scope is added as a
    keyword argument on every call.

    Returns:
        Whatever the parent class's ``reset`` returns.

    Raises:
        TypeError: If the caller also passes ``make_task_fn`` as a keyword,
            since the same keyword would then be supplied twice.
    """
    # Resolve the parent implementation first, then invoke it with the
    # injected task factory.
    parent_reset = super().reset
    return parent_reset(*args, make_task_fn=make_task_fn, **kwargs)

env = MyNMMO(config)

return env
# if args.model_type in ["realikun", "realikun-simplified"]:
# env = NMMOTeamEnv(
# config, team_helper, rewards_config, moves_only=args.moves_only)
Expand Down