Skip to content

Commit

Permalink
allow patch_batch_size to be adjusted in the forward() method
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorrent committed Jan 18, 2025
1 parent 175fce6 commit cff0dcb
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion bytelatent/data/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def patch(
preds: torch.Tensor | None = None,
entropies: torch.Tensor | None = None,
threshold: float = None,
patching_batch_size: int | None = None,
) -> torch.Tensor:
"""
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
Expand Down Expand Up @@ -539,7 +540,11 @@ def patch(
scores = calculate_entropies(
tokens,
self.entropy_model,
self.patching_batch_size,
(
patching_batch_size
if patching_batch_size is not None
else self.patching_batch_size
),
self.device,
)
if self.log_time:
Expand Down

0 comments on commit cff0dcb

Please sign in to comment.