diff --git a/pickled_task_with_embs.pkl b/pickled_task_with_embs.pkl
new file mode 100644
index 0000000..4aa383c
Binary files /dev/null and b/pickled_task_with_embs.pkl differ
diff --git a/train.py b/train.py
index 5807af5..7eed8d5 100644
--- a/train.py
+++ b/train.py
@@ -8,6 +8,7 @@
 import pufferlib.emulation
 import pufferlib.frameworks.cleanrl
 import pufferlib.registry.nmmo
+from nmmo.task.task_api import make_team_tasks
 import torch
 
 import clean_pufferl
@@ -61,7 +62,7 @@
   help="reset on death (default: False)")
 parser.add_argument(
   "--env.num_maps", dest="num_maps", type=int, default=128,
-  help="number of maps to use for training (default: 1)")
+  help="number of maps to use for training (default: 128)")
 parser.add_argument(
   "--env.maps_path", dest="maps_path", type=str, default="maps/train/",
   help="path to maps to use for training (default: None)")
@@ -100,7 +101,7 @@
   help="number of cores to use for training (default: num_envs)")
 parser.add_argument(
   "--rollout.num_envs", dest="num_envs", type=int, default=4,
-  help="number of environments to use for training (default: 1)")
+  help="number of environments to use for training (default: 4)")
 parser.add_argument(
   "--rollout.num_buffers", dest="num_buffers", type=int, default=4,
   help="number of buffers to use for training (default: 4)")
@@ -142,7 +143,7 @@
 parser.add_argument(
   "--ppo.bptt_horizon", dest="bptt_horizon", type=int, default=8,
   help="train on bptt_horizon steps of a rollout at a time. "
-  "use this to reduce GPU memory (default: 16)")
+  "use this to reduce GPU memory (default: 8)")
 
 parser.add_argument(
   "--ppo.training_batch_size",
@@ -198,7 +199,37 @@
 )
 
 def make_env():
-  return nmmo.Env(config)
+  import pickle as pkl
+  import numpy as np
+  import random
+
+  import os
+  print('cwd', os.getcwd())
+  with open('./pickled_task_with_embs.pkl', 'rb') as f:
+    task_spec = pkl.load(f)
+
+  # tasks = [d[1] for d in task_spec]
+  num_tasks = len(task_spec)
+  teams = team_helper.teams
+  single_task = task_spec[0]
+
+  # make_task_fn = lambda: tasks
+  # task_spec_sampled = np.random.choice(task_spec, size=len(teams), replace=False)
+  task_spec_sampled = random.sample(task_spec, len(teams))
+  tasks = make_team_tasks(teams, task_spec_sampled)
+  make_task_fn = lambda: tasks
+
+  # env = nmmo.Env(config)
+  class NMMOTaskWrapper(nmmo.Env):
+    def __init__(self, config):
+      super().__init__(config)
+
+    def reset(self, *args, **kwargs):
+      return super().reset(*args, make_task_fn=make_task_fn, **kwargs)
+
+  env = NMMOTaskWrapper(config)
+
+  return env
   # if args.model_type in ["realikun", "realikun-simplified"]:
   #   env = NMMOTeamEnv(
   #     config, team_helper, rewards_config, moves_only=args.moves_only)
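
Note: the new make_env above boils down to one pattern: build a task list once, then subclass nmmo.Env so every reset re-installs that list through the make_task_fn keyword of reset (which the hunk relies on nmmo accepting). Below is a minimal standalone sketch of that pattern, assuming the same pickled task-spec file and a teams mapping like the diff's team_helper.teams; build_task_env is a hypothetical helper name and is not part of the patch.

  import pickle
  import random

  import nmmo
  from nmmo.task.task_api import make_team_tasks

  def build_task_env(config, teams, spec_path='./pickled_task_with_embs.pkl'):
    # Load the pre-generated task specs (with embeddings) from disk.
    with open(spec_path, 'rb') as f:
      task_spec = pickle.load(f)

    # One spec per team, sampled without replacement, turned into team tasks.
    # Requires len(task_spec) >= len(teams).
    sampled = random.sample(task_spec, len(teams))
    tasks = make_team_tasks(teams, sampled)

    class TaskEnv(nmmo.Env):
      # Every reset re-installs the same pre-built task list.
      def reset(self, *args, **kwargs):
        return super().reset(*args, make_task_fn=lambda: tasks, **kwargs)

    return TaskEnv(config)

Because sampling happens once at construction, every episode in a given env instance reuses the same tasks, which matches the patch; moving the random.sample and make_team_tasks calls inside reset would instead draw fresh tasks each episode.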