Commit
Merge pull request #4 from Cyrilvallez/dev
Add perplexity computation
Cyrilvallez authored Dec 21, 2023
2 parents cb371a0 + 4d36cb5 commit 51e5a3a
Showing 3 changed files with 80 additions and 2 deletions.
2 changes: 1 addition & 1 deletion memory_estimator.py
@@ -7,7 +7,7 @@
import torch
import numpy as np

-from textwiz import HFModel, loader, warnings_suppressor, utils
+from .textwiz import HFModel, loader, warnings_suppressor, utils

# Remove warning when tokenizing sequences longer than expected: we know we are doing it!
logger = logging.getLogger('transformers.tokenization_utils_base')
2 changes: 1 addition & 1 deletion memory_estimator_wrapper.py
@@ -6,7 +6,7 @@

import torch

-import textwiz
+from . import textwiz


def dispatch_jobs_srun(gpu_footprints: list[int], num_gpus: int, commands: list[str], cpus_per_task: int | list[int] = 2,
78 changes: 78 additions & 0 deletions textwiz/generation.py
@@ -908,6 +908,84 @@ def truncate_conversation(self, conversation: GenericConversation, max_new_token
input_length = input.shape[-1]

return new_conv


def perplexity(self, text: str, stride: int = 512) -> float:
    """Compute the perplexity of the given `text`. If the number of tokens is larger than the maximum context
    size, use a sliding window with the given `stride`, i.e. move the input by `stride` tokens at each iteration.
    Thus, after the first iteration the model always has a context of `max_context_size - stride` tokens when
    computing the negative log-likelihood of the `stride` new tokens. A smaller `stride` gives better results
    but requires more forward passes.

    Parameters
    ----------
    text : str
        Text for which to compute the perplexity.
    stride : int, optional
        Sliding window parameter, by default 512.

    Returns
    -------
    float
        The perplexity of `text` given the current model.
    """

    encoding = self.tokenizer(text, return_tensors='pt')

    max_length = self.get_context_size()
    seq_len = encoding.input_ids.shape[-1]

    if stride >= max_length:
        raise RuntimeError('The stride should be lower than the model maximum context size.')

    # Use reduction='sum' instead of 'mean' to compute the correct mean at the end (the mean of the means would be incorrect)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
    loss = torch.tensor(0., requires_grad=False, device=self.input_device)

    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        # Will always be equal to stride except on the last iteration
        target_length = end_loc - prev_end_loc

        # Compute inputs and targets
        input_ids = encoding.input_ids[:, begin_loc:end_loc].to(self.input_device)
        target_ids = input_ids.clone()

        # This masks the (max_length - stride) already processed targets in the loss
        target_ids[:, :-target_length] = -100

        # Remove the first target, as we cannot compute the probability distribution for the first token of the
        # input. This is not an issue: on the first iteration the first token is <BOS>, and on later iterations
        # it is masked anyway.
        target_ids = target_ids[:, 1:]
        # Remove the batch dimension of size 1 (empty)
        target_ids = target_ids.squeeze(0)

        with torch.no_grad():
            outputs = self.model(input_ids)

        # Extract the logits for all tokens except the last one (we do not care about the probability
        # distribution of what would be the new token if we were performing auto-regressive generation)
        logits = outputs.logits[:, :-1, :]
        # Remove the batch dimension of size 1 (empty)
        logits = logits.squeeze(0)

        # Logits now have shape (len(input_ids)-1, vocab_size). This corresponds to the logit distribution
        # for each token given the previous ones. Instead of applying a softmax, taking the probability
        # corresponding to the input token, and summing, we can directly use CrossEntropyLoss as a trick.
        # That is, we treat it as the loss for a problem with C=vocab_size classes and an "artificial batch"
        # of size len(input_ids)-1.
        loss += criterion(logits, target_ids)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    # Don't forget to apply the exponential after dividing by the total number of predicted tokens (seq_len - 1)
    perplexity_output = torch.exp(loss / (seq_len - 1))

    return perplexity_output.item()
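
For reference, a minimal sketch of how the new method could be called from user code (the model name and the loading call below are assumptions for illustration, not part of this commit):

from textwiz import HFModel

# Hypothetical usage -- the model name passed to HFModel is an assumed placeholder
model = HFModel('llama2-7B')
long_text = "Some long document. " * 1000   # long enough to exceed the context size and trigger the sliding window
ppl = model.perplexity(long_text, stride=512)
print(f'Perplexity: {ppl:.2f}')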
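
The comment in the loop relies on CrossEntropyLoss with reduction='sum' being the summed negative log-likelihood of the target tokens. A small self-contained check of that equivalence on random tensors (purely illustrative, no model involved):

import torch

vocab_size, num_targets = 32, 10
logits = torch.randn(num_targets, vocab_size)            # shape (len(input_ids)-1, vocab_size)
targets = torch.randint(0, vocab_size, (num_targets,))   # next-token ids

# CrossEntropyLoss with reduction='sum' ...
ce = torch.nn.CrossEntropyLoss(reduction='sum')(logits, targets)
# ... equals the sum of -log p(target | context) taken from the log-softmax
nll = -torch.log_softmax(logits, dim=-1)[torch.arange(num_targets), targets].sum()
assert torch.allclose(ce, nll)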



