[WIP] Fix realtime entropy patching #26
Draft
+20
−15
This is a followup to the comment made here.
I want to load and train an entropy model alongside the latent model, rather than loading it from a checkpoint. The current code does not support this, so I added an additional `PatcherArgs` attribute called `entropy_model`. Here, you can pass a PyTorch module directly instead of a checkpoint path.
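A minimal sketch of how this could be used; only `PatcherArgs` and the new `entropy_model` attribute come from this change, while the import path, the `patching_mode` value, and the toy module are assumptions on my part:

```python
import torch.nn as nn

from bytelatent.data.patcher import PatcherArgs  # import path assumed

# Toy stand-in for the entropy model; in practice this would be the small
# byte-level LM being trained alongside the latent model.
entropy_model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True),
    num_layers=2,
)

# The module is handed to the patcher directly, instead of pointing
# PatcherArgs at a checkpoint directory.
patcher_args = PatcherArgs(
    patching_mode="entropy",      # assumed value for entropy-based patching
    entropy_model=entropy_model,  # new attribute added by this PR
)
```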
I also removed the `self.output_proj` variable, because it crashes when using entropy patching: the `entropy_model_checkpoint_dir` attribute does not exist on `LocalModelArgs`. Regardless, this variable is unused, and probably not necessary?

Finally, I removed a logger warning that was extremely spammy. We are already gating this warning with the `BLT_SUPPRESS_ATTN_ERROR` environment variable, and I'm assuming you didn't intend to throw the same warning on every forward pass (see the sketch at the end of this description). I'm sure this was an oversight, and you meant to remove this extra logging.

Let me know if you have any questions!
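For reference, a rough sketch of the pattern I mean; the class name, the gating condition, and the warning text here are hypothetical, and only the `BLT_SUPPRESS_ATTN_ERROR` variable comes from the codebase:

```python
import logging
import os

import torch.nn as nn

logger = logging.getLogger(__name__)


class LocalModelSketch(nn.Module):
    """Hypothetical stand-in for the local model."""

    def __init__(self) -> None:
        super().__init__()
        # Warn once, at construction time, gated by the env var
        # (the exact condition used in the repo may differ).
        if os.environ.get("BLT_SUPPRESS_ATTN_ERROR", "0") != "1":
            logger.warning(
                "Hypothetical warning text; set BLT_SUPPRESS_ATTN_ERROR=1 to silence it."
            )

    def forward(self, x):
        # No per-call logging here, so the warning fires at most once per instance.
        return x
```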