diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index d32f168..0063c80 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -490,6 +490,7 @@ def patch( preds: torch.Tensor | None = None, entropies: torch.Tensor | None = None, threshold: float = None, + patching_batch_size: int | None = None, ) -> torch.Tensor: """ tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched @@ -539,7 +540,11 @@ def patch( scores = calculate_entropies( tokens, self.entropy_model, - self.patching_batch_size, + ( + patching_batch_size + if patching_batch_size is not None + else self.patching_batch_size + ), self.device, ) if self.log_time: