
[WIP] Fix realtime entropy patching #26

Draft: wants to merge 5 commits into main
Conversation

Vectorrent (Contributor) commented:
This is a follow-up to the comment made here.

I want to load and train an entropy model alongside the latent model, rather than loading it from a checkpoint. The current code does not support this, so I added a new PatcherArgs attribute called entropy_model. With it, you can pass a PyTorch module directly instead of a checkpoint path.
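For context, a minimal sketch of the intended usage, assuming PatcherArgs is exposed from bytelatent/data/patcher.py; the patching_mode field and the build() helper are assumptions for illustration, not verified names from the repo:

```python
import torch.nn as nn

from bytelatent.data.patcher import PatcherArgs

# Stand-in entropy model that is trained alongside the latent model.
entropy_model = nn.LSTM(input_size=256, hidden_size=512, batch_first=True)

# Pass the live module directly instead of an entropy_model_checkpoint_dir.
patcher_args = PatcherArgs(
    patching_mode="entropy",      # assumed name of the mode switch
    entropy_model=entropy_model,  # new attribute added by this PR
)
patcher = patcher_args.build()    # assumed construction helper
```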

I also removed the self.output_proj variable, because it crashes when using entropy patching: the entropy_model_checkpoint_dir attribute does not exist on LocalModelArgs. In any case, the variable is unused and probably unnecessary.

Finally, I removed a logger warning that was extremely spammy. The warning is already gated by the BLT_SUPPRESS_ATTN_ERROR environment variable, and I assume you did not intend to emit it on every forward pass. I'm guessing this was an oversight and the extra logging was meant to be removed.

Let me know if you have any questions!

facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 18, 2025
Vectorrent changed the title from "Fix realtime entropy patching" to "[WIP] Fix realtime entropy patching" on Jan 18, 2025
Vectorrent (Contributor, Author) commented on Jan 18, 2025:

I noticed that the Patcher expects a static batch size to be set, but my custom model is more dynamic than that and often uses a variable batch size. I therefore need to adjust the patcher's batch size at runtime, during training, and made that possible with an override in the forward method (a rough sketch follows below).
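Roughly, the override looks like this sketch; it is illustrative only, not the repo's actual Patcher, and (per the next comment) this approach was later reverted in favor of the entropies argument:

```python
import torch

class PatcherSketch:
    """Illustrative stand-in, not the repo's Patcher."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size  # normally fixed when the patcher is built

    def forward(self, tokens: torch.Tensor, batch_size: int | None = None) -> torch.Tensor:
        # Allow the training loop to override the static batch size per call.
        if batch_size is not None:
            self.batch_size = batch_size
        return tokens.view(self.batch_size, -1)
```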

Vectorrent marked this pull request as draft on January 18, 2025 01:48
Vectorrent (Contributor, Author) commented:
I reverted some of the Patcher code to its original style, since I realized I could just use the entropies argument to handle the variable batch size.

I also fixed a warning here, since passing entropies in that way was causing this:

bytelatent/data/patcher.py:535: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scores = torch.tensor(entropies, dtype=torch.float32)
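The fix amounts to not re-wrapping an existing tensor with torch.tensor(...). A sketch of the kind of change (paraphrased, not the exact patcher.py code):

```python
import torch

def to_scores(entropies) -> torch.Tensor:
    # torch.tensor(existing_tensor) triggers the copy-construct UserWarning;
    # detach().clone() (or torch.as_tensor for list/array inputs) avoids it.
    if isinstance(entropies, torch.Tensor):
        return entropies.detach().clone().to(dtype=torch.float32)
    return torch.as_tensor(entropies, dtype=torch.float32)
```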
