Skip to content

Commit

Permalink
better sliding window
Browse files Browse the repository at this point in the history
  • Loading branch information
austinbrown5 committed Sep 3, 2024
1 parent 3f19c7e commit 0249ab9
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions mimir/attacks/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,17 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke
text_ids = target_encodings.input_ids.to(model.device)

max_length = model.max_length
total_length = prefix_ids.size(1) + text_ids.size(1)

if prefix_ids.size(1) >= max_length:
raise ValueError("Prefix length exceeds or equals the model's maximum context window.")
stride = model.stride

labels = torch.cat((prefix_ids, text_ids), dim=1)

total_loss = 0
with torch.no_grad():
for i in range(0, labels.size(1), stride):
begin_loc = max(i + stride - max_length, 0)
end_loc = min(i + stride, labels.size(1))
for i in range(0, labels.size(1), max_length):
begin_loc = max(i - max_length, 0)
end_loc = min(i, labels.size(1))
trg_len = end_loc - i
input_ids = labels[:, begin_loc:end_loc].to(model.device)
target_ids = input_ids.clone()
Expand Down

0 comments on commit 0249ab9

Please sign in to comment.