sb3 #38
base: main
Conversation
Data for this model: https://wandb.ai/h975894552/Syllabus/runs/z24gccwf/overview?nw=nwuserh975894552
Left some comments. Please make the requested changes to simplify the PR a bit, and let me know if you have any questions about the callbacks.
Remove any changes to this file
Remove any changes to this file
Remove any changes to this file
Remove any changes to this file
    )
    env = openai_gym.make(f"procgen-{env_id}-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels)
    env = GymV21CompatibilityV0(env=env)
    components = MultiProcessingComponents(task_queue=task_queue, update_queue=update_queue)
Suggested change:
-    components = MultiProcessingComponents(task_queue=task_queue, update_queue=update_queue)
+    components = curriculum.get_components()
Remove this file: git rm --cached syllabus/examples/training_scripts/wandb/run-20240423_020001-cymykoqj/files/events.out.tfevents.1713852002.WenranLaoGong.157219.0
profiling_results.prof
Outdated
Remove this file: git rm --cached profiling_results.prof
    model.policy.train()
    return mean_returns, stddev_returns, normalized_mean_returns


class CustomCallback(BaseCallback):
Look at the documentation here to change the behavior of the callback: https://stable-baselines3.readthedocs.io/en/v1.0/guide/callbacks.html
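For reference, here's a bare-bones sketch of the callback hooks from those docs (the class name and structure are just illustrative):

from stable_baselines3.common.callbacks import BaseCallback

class EvalCallbackSketch(BaseCallback):
    def _on_step(self) -> bool:
        # Called after every environment step; return False to stop training early.
        return True

    def _on_rollout_end(self) -> None:
        # Called once per update, right after each rollout is collected.
        pass

    def _on_training_end(self) -> None:
        # Called a single time, at the very end of model.learn().
        pass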
        mean_eval_returns, _, _ = level_replay_evaluate_sb3(args.env_id, model, args.num_eval_episodes, num_levels=0)
        writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, self.global_step)
This code should only be run once every update. There is a different callback method, _on_training_end, that you should probably use.
If you need access to any data from training, try printing out self.locals or self.globals from within the callback method to see what is available.
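For example, something like this inside the callback (just a debugging sketch):

    def _on_rollout_end(self) -> None:
        # self.locals / self.globals mirror the local and global variables of the
        # SB3 training loop at the moment the callback fires.
        print(sorted(self.locals.keys()))
        print(self.num_timesteps)  # total environment steps collected so far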
I checked the hyperparameters you included, but I didn't review any that you excluded. I'll revisit that later.
""" | ||
return True | ||
|
||
def _on_rollout_end(self) -> None: |
I'm not sure, but I think you can put this function in the CustomCallback rather than creating two separate ones.
        return True

    def _on_rollout_end(self) -> None:
        if self.num_timesteps % self.eval_freq == 0:
This isn't necessary; we should just evaluate every time this function is called. It should happen every 16,000 steps, but try running it and make sure that's what happens.
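Roughly like this (a sketch; it assumes writer, args, and level_replay_evaluate_sb3 are in scope as in your training script, and uses SB3's self.num_timesteps counter):

    def _on_rollout_end(self) -> None:
        # This hook already fires once per update (n_steps * n_envs frames),
        # so no extra eval_freq check is needed.
        mean_eval_returns, _, _ = level_replay_evaluate_sb3(
            args.env_id, self.model, args.num_eval_episodes, num_levels=0
        )
        writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, self.num_timesteps)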
def wrap_vecenv(vecenv):
    vecenv.is_vector_env = True
    vecenv = VecMonitor(venv=vecenv, filename=None)
    vecenv = VecNormalize(venv=vecenv, norm_obs=False, norm_reward=True)
    vecenv = VecNormalize(venv=vecenv, norm_obs=False, norm_reward=True, training=False)
    return vecenv
This function is used for both eval and training; we should probably pass training as an argument to wrap_vecenv, so that training=True for the training envs, right?
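Something like this sketch (VecNormalize's training flag controls whether the reward statistics keep updating):

from stable_baselines3.common.vec_env import VecMonitor, VecNormalize

def wrap_vecenv(vecenv, training=True):
    vecenv.is_vector_env = True
    vecenv = VecMonitor(venv=vecenv, filename=None)
    # training=True for the training envs (stats keep updating),
    # training=False for the eval envs (stats are frozen).
    vecenv = VecNormalize(venv=vecenv, norm_obs=False, norm_reward=True, training=training)
    return vecenv

Then the call sites would be something like wrap_vecenv(train_venv) and wrap_vecenv(eval_venv, training=False), with whatever names your script uses for the two vector envs.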
Remove these changes
Remove these changes
Remove these changes
    n_epochs=3,
    clip_range_vf=0.2,
    ent_coef=0.01,
    batch_size=256 * 64,
Suggested change:
-    batch_size=256 * 64,
+    batch_size=2048
batch_size is actually the minibatch size. We want 8 minibatches for the 256 * 64 steps collected per rollout, so 2048 steps per minibatch.
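The arithmetic, assuming n_steps=256 and 64 parallel envs as the 256 * 64 expression implies:

n_steps = 256                      # steps per env per rollout (assumed)
num_envs = 64                      # parallel envs (assumed)
rollout_size = n_steps * num_envs  # 16384 transitions collected per update
batch_size = rollout_size // 8     # 2048 -> 8 minibatches per epoch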

    print("Creating model")
    model = PPO(
        "CnnPolicy",
You're going to need to find a way to replace this with the ProcgenAgent model; see the custom policy guide: https://stable-baselines3.readthedocs.io/en/v1.0/guide/custom_policy.html
Take a look at the advanced example: https://stable-baselines3.readthedocs.io/en/v1.0/guide/custom_policy.html#advanced-example
I think if you replace the CustomNetwork with our Policy (the parent class of ProcgenAgent), then the code they have there might just work out of the box.
class Policy(nn.Module):
I'm not sure if this works, but you should try this. It would be a lot simpler this way:
from typing import Callable, Dict, List, Optional, Type, Union

import gym
import torch.nn as nn
from stable_baselines3.common.policies import ActorCriticPolicy

from syllabus.examples.models.procgen_model import Policy


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        *args,
        **kwargs,
    ):
        super(CustomActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = Policy(...)
This is a small change to the example in the documentation here: https://stable-baselines3.readthedocs.io/en/v1.0/guide/custom_policy.html#advanced-example
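If that works, using it should just be a matter of passing the class to PPO, something like this (venv here stands for your wrapped vector env):

model = PPO(
    CustomActorCriticPolicy,  # instead of "CnnPolicy"
    venv,
    n_epochs=3,
    clip_range_vf=0.2,
    ent_coef=0.01,
    batch_size=2048,
)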
        return value, action_log_probs, dist_entropy


class Sb3ProcgenAgent(CustomPolicy):
I think SB3's model will exclusively call forward, so this class isn't necessary.
    def get_value(self, input):
        value, _, _ = self.network(input)
        return value

    def evaluate_actions(self, input, rnn_hxs, masks, action):
        value, actor_features = self.network(input, rnn_hxs, masks)
        dist = self.dist(actor_features)

        action_log_probs = dist.log_prob(action)
        dist_entropy = dist.entropy().mean()
        return value, action_log_probs, dist_entropy
See comment below; I'm not sure you need to add these methods.