
Commit

Eval and logging improvements
RyanNavillus committed Nov 13, 2024
1 parent faad8e0 commit d24f6e7
Showing 4 changed files with 150 additions and 16 deletions.
config.yaml (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ train:

num_envs: 16
envs_per_worker: 1
-envs_per_batch: 6
+envs_per_batch: 8
env_pool: True
verbose: True
data_dir: runs
reinforcement_learning/clean_pufferl.py (19 changes: 10 additions & 9 deletions)
@@ -123,15 +123,16 @@ def create(
resume_state = {}
path = os.path.join(config.data_dir, exp_name)
if False and os.path.exists(path):
-    trainer_path = os.path.join(path, "trainer_state.pt")
-    resume_state = torch.load(trainer_path)
-    model_path = os.path.join(path, resume_state["model_name"])
-    agent = torch.load(model_path, map_location=device)
-    print(
-        f'Resumed from update {resume_state["update"]} '
-        f'with policy {resume_state["model_name"]}'
-    )
-else:
+    pass
+    # trainer_path = os.path.join(path, "trainer_state.pt")
+    # resume_state = torch.load(trainer_path)
+    # model_path = os.path.join(path, resume_state["model_name"])
+    # agent = torch.load(model_path, map_location=device)
+    # print(
+    #     f'Resumed from update {resume_state["update"]} '
+    #     f'with policy {resume_state["model_name"]}'
+    # )
+elif not eval_mode:
    agent = pufferlib.emulation.make_object(
        agent, agent_creator, [pool.driver_env], agent_kwargs
    )
reinforcement_learning/stat_wrapper.py (2 changes: 1 addition & 1 deletion)
@@ -179,7 +179,7 @@ def _process_stats_and_early_stop(self, agent_id, reward, terminated, truncated,
for key, val in list(achieved.items()) + list(performed.items()):
    info["stats"][key] = float(val)

-if self._stat_prefix:
+if self._stat_prefix is not None:
    info = {self._stat_prefix: info}

return truncated, info
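
The switch from a truthiness test to an explicit is-not-None check matters when the prefix is an empty string: the old condition skipped wrapping for both None and "", while the new one only skips when no prefix was set at all. A minimal standalone sketch of the difference, using a hypothetical wrap_stats helper and made-up values (not part of the commit):

def wrap_stats(info, stat_prefix):
    # Old behavior: any falsy prefix ("" or None) leaves info unwrapped.
    old = {stat_prefix: info} if stat_prefix else info
    # New behavior: only None leaves info unwrapped; "" still wraps.
    new = {stat_prefix: info} if stat_prefix is not None else info
    return old, new

print(wrap_stats({"kills": 3}, ""))    # ({'kills': 3}, {'': {'kills': 3}})
print(wrap_stats({"kills": 3}, None))  # ({'kills': 3}, {'kills': 3})
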
train_helper.py (143 changes: 138 additions & 5 deletions)
@@ -1,3 +1,6 @@
+from collections import defaultdict
+import copy
+import importlib
import os
import time
import logging
@@ -11,7 +14,7 @@
from nmmo.render.replay_helper import FileReplayHelper
from nmmo.task.task_spec import make_task_from_spec
from syllabus.curricula import PrioritizedLevelReplay
-from reinforcement_learning import clean_pufferl
+from reinforcement_learning import clean_pufferl, environment

# Related to torch.use_deterministic_algorithms(True)
# See also https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
@@ -39,7 +42,7 @@ def init_wandb(args, resume=True):
"name": args.exp_name,
"monitor_gym": True,
"save_code": True,
"resume": resume,
"resume": False,
}
if args.wandb.group is not None:
wandb_kwargs["group"] = args.wandb.group
@@ -64,17 +67,146 @@ def train(args, env_creator, agent_creator, syllabus=None):
syllabus.curriculum.curriculum.evaluator.set_agent(data.agent)
syllabus.start()

eval_args = copy.deepcopy(args)
eval_data, env_outputs = setup_eval(eval_args, data.agent)

while not clean_pufferl.done_training(data):
clean_pufferl.evaluate(data)
clean_pufferl.train(data)

# Evaluate on test seeds
print("Evaluating")
env_outputs = evaluate_agent(args, eval_data, env_outputs, data.wandb, data.global_step)

if syllabus is not None:
-syllabus.log_metrics(data.wandb, step=None)
+syllabus.log_metrics(data.wandb, step=data.global_step)

clean_pufferl.train(data)

print("Done training. Saving data...")
clean_pufferl.close(data)
clean_pufferl.close(eval_data)
print("Run complete.")


def unroll_nested_dict(d):
if not isinstance(d, dict):
return d

for k, v in d.items():
if isinstance(v, dict):
for k2, v2 in unroll_nested_dict(v):
yield f"{k}/{k2}", v2
else:
yield k, v
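
As a standalone illustration (made-up values, not part of the commit), the new unroll_nested_dict generator flattens a nested info dict into slash-separated keys, which is how evaluate_agent below collects per-policy stats:

info = {"stats": {"kills": 3, "level": {"max": 7}}, "return": 1.5}
flat = dict(unroll_nested_dict(info))
# flat == {"stats/kills": 3, "stats/level/max": 7, "return": 1.5}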


def setup_eval(args, agent):
# Set the train config for replay
args.train.num_envs = 8
args.train.envs_per_batch = 8
args.train.envs_per_worker = 1
args.track = True
args.reward_wrapper.stat_prefix = "eval"
# Disable env pool - see the comment about next_lstm_state in clean_pufferl.evaluate()
args.train.env_pool = False
args.env.resilient_population = 0
# Set the reward wrapper for replay
args.reward_wrapper.eval_mode = True
args.reward_wrapper.early_stop_agent_num = 0

# # Use the policy pool helper functions to create kernel (policy-agent mapping)
# args.train.pool_kernel = pp.create_kernel(
# args.env.num_agents, len(policies), shuffle_with_seed=args.train.seed
# )
agent_module = importlib.import_module(f"agent_zoo.{args['agent']}")

env_creator = environment.make_env_creator(reward_wrapper_cls=agent_module.RewardWrapper)
data = clean_pufferl.create(
config=args.train,
agent=agent,
agent_kwargs={"args": args},
env_creator=env_creator,
env_creator_kwargs={"env": args.env, "reward_wrapper": args.reward_wrapper},
eval_mode=True,
eval_model_path=os.path.join(args.train.data_dir, args.exp_name, "eval"),
policy_selector=pp.AllPolicySelector(args.train.seed),
exp_name=args.exp_name,
track=args.track,
vectorization=args.vectorization,
)
env_outputs = data.pool.recv() # This resets the env
return data, env_outputs


def evaluate_agent(args, data, env_outputs, train_wandb, global_step):
o, r, d, t, i, env_id, mask = env_outputs

# Evaluate agent
eval_returns = []
ep_returns = torch.zeros(8 * args.env.num_agents)
while len(eval_returns) <= 8 * args.env.num_agents:
with torch.no_grad():
o = torch.as_tensor(o)
r = torch.as_tensor(r).float().to(data.device).view(-1)
d = torch.as_tensor(d).float().to(data.device).view(-1)

# env_pool must be false for the lstm to work
next_lstm_state = data.next_lstm_state
if next_lstm_state is not None:
next_lstm_state = (
next_lstm_state[0][:, env_id],
next_lstm_state[1][:, env_id],
)

actions, logprob, value, next_lstm_state = data.policy_pool.forwards(
o.to(data.device), next_lstm_state
)

if next_lstm_state is not None:
h, c = next_lstm_state
data.next_lstm_state[0][:, env_id] = h
data.next_lstm_state[1][:, env_id] = c

value = value.flatten()

data.pool.send(actions.cpu().numpy())
env_outputs = data.pool.recv()
o, r, d, t, i, env_id, mask = env_outputs
ep_returns += r
for index, done in enumerate(d):
if done:
eval_returns.append(ep_returns[index].item())
ep_returns[index] = 0
data.stats = {}

# Get stats only from the learner
i = data.policy_pool.update_scores(i, "return")
# TODO: Update this for policy pool
for ii, ee in zip(i["learner"], env_id):
ii["env_id"] = ee
infos = defaultdict(lambda: defaultdict(list))
for policy_name, policy_i in i.items():
for agent_i in policy_i:
for name, dat in unroll_nested_dict(agent_i):
infos[policy_name][name].append(dat)

for k, v in infos["learner"].items():
try: # TODO: Better checks on log data types
# Skip the unnecessary info from the stats
if not any(skip in k for skip in ["curriculum/Task_", "env_id"]):
data.stats[k] = np.mean(v)
except:
continue

# print("logging", global_step)
train_wandb.log({
"global_step": global_step,
**{f"{k}": v for k, v in data.stats.items()}
})

return env_outputs
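
The next_lstm_state handling above indexes the recurrent state by env_id so that each environment keeps its own hidden state across batches, which is why setup_eval disables env_pool. A minimal sketch of that slicing pattern with assumed shapes (hypothetical sizes, not taken from the commit):

import torch

num_layers, num_envs, hidden_size = 1, 8, 32
h = torch.zeros(num_layers, num_envs, hidden_size)
c = torch.zeros(num_layers, num_envs, hidden_size)

env_id = torch.tensor([2, 5])                # envs present in the current batch
h_batch, c_batch = h[:, env_id], c[:, env_id]
# ... a recurrent policy forward pass would produce updated h_batch / c_batch here ...
h[:, env_id], c[:, env_id] = h_batch, c_batch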


def sweep(args, env_creator, agent_creator):
sweep_id = wandb.sweep(sweep=args.sweep, project=args.wandb.project)

@@ -95,7 +227,7 @@ def main():
wandb.agent(sweep_id, main, count=20)


-def generate_replay(args, env_creator, agent_creator, stop_when_all_complete_task=True, seed=None):
+def generate_replay(args, env_creator, agent_creator, stop_when_all_complete_task=True, seed=None, agent=None):
assert args.eval_model_path is not None, "eval_model_path must be set for replay generation"
policies = pp.get_policy_names(args.eval_model_path)
assert len(policies) > 0, "No policies found in eval_model_path"
@@ -125,6 +257,7 @@ def generate_replay(args, env_creator, agent_creator, stop_when_all_complete_tas
data = clean_pufferl.create(
config=args.train,
agent_creator=agent_creator,
+agent=agent,
agent_kwargs={"args": args},
env_creator=env_creator,
env_creator_kwargs={"env": args.env, "reward_wrapper": args.reward_wrapper},
