From 0249ab991ef26902fd88729a957d2b0274babb9a Mon Sep 17 00:00:00 2001
From: Austin Brown
Date: Tue, 3 Sep 2024 18:23:29 -0400
Subject: [PATCH] better sliding window

---
 mimir/attacks/recall.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py
index be890d0..38f9fa4 100644
--- a/mimir/attacks/recall.py
+++ b/mimir/attacks/recall.py
@@ -81,19 +81,17 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke
         text_ids = target_encodings.input_ids.to(model.device)
         max_length = model.max_length
-        total_length = prefix_ids.size(1) + text_ids.size(1)
 
         if prefix_ids.size(1) >= max_length:
             raise ValueError("Prefix length exceeds or equals the model's maximum context window.")
 
-        stride = model.stride
 
         labels = torch.cat((prefix_ids, text_ids), dim=1)
         total_loss = 0
 
         with torch.no_grad():
-            for i in range(0, labels.size(1), stride):
-                begin_loc = max(i + stride - max_length, 0)
-                end_loc = min(i + stride, labels.size(1))
+            for i in range(0, labels.size(1), max_length):
+                begin_loc = max(i - max_length, 0)
+                end_loc = min(i, labels.size(1))
                 trg_len = end_loc - i
                 input_ids = labels[:, begin_loc:end_loc].to(model.device)
                 target_ids = input_ids.clone()