Commit 874c30a

Add loss generating token counts (#1610)

dakinggg authored Oct 27, 2024
1 parent e11a43c
Showing 14 changed files with 106 additions and 50 deletions.
5 changes: 3 additions & 2 deletions llmfoundry/callbacks/loss_perp_v_len_callback.py
@@ -10,6 +10,7 @@
 from torchmetrics import Metric
 
 from llmfoundry.models.mpt import ComposerMPTCausalLM
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 from llmfoundry.utils.warnings import experimental_class
 
 __all__ = [
@@ -33,7 +34,7 @@ def __init__(
         self,
         log_batch_interval: int,
         compute_batch_interval: int,
-        ignore_index: int = -100,
+        ignore_index: int = CROSS_ENTROPY_IGNORE_INDEX,
     ):
         if compute_batch_interval > log_batch_interval:
             raise ValueError(
@@ -69,7 +70,7 @@ def after_backward(self, state: State, logger: Logger) -> None:
         labels = state.batch['labels']
         if state.model.shift_labels:
             labels[:, :-1] = labels[:, 1:].detach().clone()
-            labels[:, -1] = -100
+            labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
         seq_parallel_world_size = getattr(
             state.model.model.transformer,
             'seq_parallel_world_size',
24 changes: 13 additions & 11 deletions llmfoundry/data/finetuning/collator.py
@@ -8,15 +8,14 @@
 import torch
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
+
 log = logging.getLogger(__name__)
 
 __all__ = [
     'Seq2SeqFinetuningCollator',
 ]
 
-# HuggingFace hardcodes the ignore index to -100
-_HF_IGNORE_INDEX = -100
-
 TokenizedExample = dict[str, list[dict[str, list[int]]]]
 
@@ -79,7 +78,7 @@ def _sequence_to_labels_none(
     cutoff: Optional[int] = None,
 ) -> list[int]:
     del is_last_turn, cutoff  # unused
-    return [_HF_IGNORE_INDEX] * len(sequence)
+    return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)
 
 
 def _sequence_to_labels_last(
@@ -91,7 +90,7 @@ def _sequence_to_labels_last(
     if is_last_turn:
         return sequence
     else:
-        return [_HF_IGNORE_INDEX] * len(sequence)
+        return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)
 
 
 def _sequence_to_labels_cutoff(
@@ -105,7 +104,7 @@ def _sequence_to_labels_cutoff(
     if len(sequence) >= cutoff:
         return sequence
     else:
-        return [_HF_IGNORE_INDEX] * len(sequence)
+        return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)
 
 
 _TARGET_POLICY_LOOKUP = {
@@ -352,7 +351,7 @@ def _process_and_batch_decoder_only(
             labels = labels[:max_seq_len]
 
             # Check to make sure there are still loss-generating tokens. Error if not.
-            if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
+            if len([l for l in labels if l != CROSS_ENTROPY_IGNORE_INDEX
+                   ],) == 0:
                 raise ValueError(
                     f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\
                     f'Pre-truncation sequence length was {orig_size}. ' +\
@@ -375,7 +375,7 @@ def _process_and_batch_decoder_only(
             # Annoyingly, we need to pad everything but input_ids
             # and attention_mask ourselves
             n_total = len(input_ids)
-            i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total)
+            i_pad = [CROSS_ENTROPY_IGNORE_INDEX] * (max_seq_len - n_total)
             if self.tokenizer.padding_side == 'left':
                 labels = i_pad + labels
             else:
@@ -444,7 +444,9 @@ def _process_and_batch_encoder_decoder(
         for context, target in contexts_and_targets:
             # We need to pad labels ourselves. Because HF.
             if len(target) < max_seq_len:
-                i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target))
+                i_pad = [
+                    CROSS_ENTROPY_IGNORE_INDEX,
+                ] * (max_seq_len - len(target))
                 target = target + i_pad
             else:
                 if not self._warned_target:
@@ -491,12 +493,12 @@ def _process_and_batch_encoder_decoder(
             ],
             dim=1)
         batch['decoder_input_ids'].masked_fill_(
-            batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
+            batch['decoder_input_ids'] == CROSS_ENTROPY_IGNORE_INDEX,
             self.tokenizer.pad_token_id,
         )
         batch['decoder_attention_mask'] = torch.not_equal(
             batch['labels'],
-            _HF_IGNORE_INDEX,
+            CROSS_ENTROPY_IGNORE_INDEX,
         )
 
         # This logic prevents trimming on at least the first batch
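
Aside (not part of the diff): a minimal sketch of what the target policies above do. Any label equal to CROSS_ENTROPY_IGNORE_INDEX is skipped by the cross-entropy loss, so masking a turn's labels makes it pure context. The helper name labels_for_turn is hypothetical; it only mirrors _sequence_to_labels_none / _sequence_to_labels_last in spirit.

    CROSS_ENTROPY_IGNORE_INDEX = -100

    def labels_for_turn(sequence: list, train_on_turn: bool) -> list:
        # Tokens we do not want to train on get the ignore index and
        # therefore generate no loss.
        if train_on_turn:
            return list(sequence)
        return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)

    prompt = [101, 7592, 102]
    response = [2023, 2003, 102]
    labels = labels_for_turn(prompt, False) + labels_for_turn(response, True)
    assert labels == [-100, -100, -100, 2023, 2003, 102]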
16 changes: 8 additions & 8 deletions llmfoundry/data/finetuning/dataloader.py
@@ -26,6 +26,7 @@
 from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
 from llmfoundry.data.text_data import build_streams
 from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 from llmfoundry.utils.exceptions import (
     FinetuningFileNotFoundError,
     MissingHuggingFaceURLSplitError,
@@ -40,9 +41,6 @@
     'build_finetuning_dataloader',
 ]
 
-# HuggingFace hardcodes the ignore index to -100
-_HF_IGNORE_INDEX = -100
-
 # Extra keys present in the dataset config dictionary beyond the constructor keys
 _ALLOWED_DATASET_KEYS = {
     'shuffle',
@@ -786,7 +784,7 @@ def build_collate_fn(
                 )
                 context = torch.logical_and(
                     batch['attention_mask'][j] == 1,
-                    batch['labels'][j] == _HF_IGNORE_INDEX,
+                    batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX,
                 )
                 print(
                     '\033[92m{}\033[00m\n'.format('CONTEXT: '),
@@ -804,7 +802,8 @@ def build_collate_fn(
                         j,
                         torch.logical_and(
                             is_subseq,
-                            batch['labels'][j] != _HF_IGNORE_INDEX,
+                            batch['labels'][j] !=
+                            CROSS_ENTROPY_IGNORE_INDEX,
                         )],
                     skip_special_tokens=False,
                     clean_up_tokenization_spaces=True,
@@ -822,7 +821,7 @@ def build_collate_fn(
                 )
                 context = torch.logical_and(
                     batch['attention_mask'][j] == 1,
-                    batch['labels'][j] == _HF_IGNORE_INDEX,
+                    batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX,
                 )
                 print(
                     '\033[92m{}\033[00m\n'.format('CONTEXT: '),
@@ -835,8 +834,9 @@ def build_collate_fn(
                 print(
                     '\033[91m{}\033[00m\n'.format('TARGET: '),
                     tokenizer.decode(
-                        batch['input_ids'][
-                            j, batch['labels'][j] != _HF_IGNORE_INDEX],
+                        batch['input_ids']
+                        [j,
+                         batch['labels'][j] != CROSS_ENTROPY_IGNORE_INDEX],
                         skip_special_tokens=False,
                         clean_up_tokenization_spaces=True,
                     ),
5 changes: 3 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
@@ -62,11 +62,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     stream_remote_local_validate,
 )
 from llmfoundry.data.finetuning.collator import (
-    _HF_IGNORE_INDEX,
     stitch_turns_decoder_only,
     stitch_turns_encoder_decoder,
 )
 from llmfoundry.tokenizers import get_date_string
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 # yapf: disable
 from llmfoundry.utils.exceptions import (
     ALLOWED_MESSAGES_KEYS,
@@ -501,7 +501,8 @@ def is_valid_ift_example(
     if len(input_ids) == 0:
         return False
 
-    if len([label for label in labels if label != _HF_IGNORE_INDEX]) == 0:
+    if len([label for label in labels if label != CROSS_ENTROPY_IGNORE_INDEX
+           ],) == 0:
         return False
 
     return True
6 changes: 4 additions & 2 deletions llmfoundry/data/packing.py
@@ -10,6 +10,8 @@
 from composer.utils import dist
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
+
 log = logging.getLogger(__name__)
 
 __all__ = [
@@ -152,7 +154,7 @@ def _convert_to_batch(
 
         pad_vals = {
             'input_ids': self.pad_token_id,
-            'labels': -100,
+            'labels': CROSS_ENTROPY_IGNORE_INDEX,
             'attention_mask': 0,
             'sequence_id': -1,
         }
@@ -317,7 +319,7 @@ def _combine_in_place(
     if 'labels' in add_on:
         # Prevents the last token in example from being trained to
        # predict the first token in add_on, which would make no sense.
-        add_on['labels'][0] = -100
+        add_on['labels'][0] = CROSS_ENTROPY_IGNORE_INDEX
 
     for k in example.keys():
         if k == 'sequence_id':
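
Aside: a toy illustration (assumed values, not repository code) of why _combine_in_place masks the first label of the appended example. Labels are shifted left by one before the loss, so without the mask the last token of one packed example would be trained to predict the first token of the next.

    import torch

    CROSS_ENTROPY_IGNORE_INDEX = -100

    example = {'labels': torch.tensor([5, 6, 7])}
    add_on = {'labels': torch.tensor([8, 9])}

    # Mask the seam so no loss crosses the packed-example boundary.
    add_on['labels'][0] = CROSS_ENTROPY_IGNORE_INDEX
    packed = torch.cat([example['labels'], add_on['labels']])
    print(packed)  # tensor([   5,    6,    7, -100,    9])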
24 changes: 22 additions & 2 deletions llmfoundry/data/utils.py
@@ -15,6 +15,7 @@
 from llmfoundry.data.finetuning.dataloader import build_collate_fn
 from llmfoundry.data.packing import BinPackCollator
 from llmfoundry.data.text_data import ConcatenatedSequenceCollatorWrapper
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 
 log = logging.getLogger(__name__)
 
@@ -83,7 +84,7 @@ def get_data_spec(
 
 def get_tokens_per_batch_func(
     decoder_only: bool = True,
-) -> Callable[[Batch], int]:
+) -> Callable[[Batch], Union[int, dict[str, int]]]:
     """Returns a callable that counts the number of tokens in a batch.
 
     Args:
@@ -95,7 +96,7 @@ def get_tokens_per_batch_func(
         Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
     """
 
-    def get_num_tokens_in_batch(batch: Batch) -> int:
+    def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
         if not isinstance(batch, Mapping) or (
             'attention_mask' not in batch and 'input_ids' not in batch
         ):
@@ -114,13 +115,32 @@ def get_num_tokens_in_batch(batch: Batch) -> int:
         else:
             input_ids_tokens = batch['input_ids'].numel()
 
+        loss_generating_tokens = None
+        if 'labels' in batch:
+            loss_generating_tokens = int(
+                torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),
+            )
+
+            # Subtract one for each example in the batch that starts with a non -100,
+            # because those will be shifted off
+            loss_generating_tokens -= int(
+                torch.sum(
+                    batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX,
+                ).item(),
+            )
+
         # For encoder decoder models only
         decoder_input_ids_tokens = 0
         if not decoder_only:
             decoder_input_ids_tokens = int(
                 torch.sum(batch['decoder_attention_mask']).item(),
             )
 
+        if loss_generating_tokens is not None:
+            return {
+                'total': input_ids_tokens + decoder_input_ids_tokens,
+                'loss_generating': loss_generating_tokens,
+            }
         return input_ids_tokens + decoder_input_ids_tokens
 
     return get_num_tokens_in_batch
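
Aside: a worked toy example (made-up tensors) of what the new counting returns for a right-padded decoder-only batch. The subtraction exists because labels are shifted left by one before the loss, so a sequence whose first label is already loss-generating loses that position to the shift.

    import torch

    CROSS_ENTROPY_IGNORE_INDEX = -100

    batch = {
        'attention_mask': torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]),
        # Row 0: two prompt tokens masked, two answer tokens live.
        # Row 1: every real token is live, so its first label is not ignored.
        'labels': torch.tensor([
            [-100, -100, 14, 15],
            [21, 22, 23, -100],
        ]),
    }

    total = int(torch.sum(batch['attention_mask']).item())  # 7
    loss_generating = int(
        torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),  # 5
    )
    loss_generating -= int(
        torch.sum(batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX).item(),  # 1
    )
    print({'total': total, 'loss_generating': loss_generating})
    # {'total': 7, 'loss_generating': 4}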
3 changes: 0 additions & 3 deletions llmfoundry/models/hf/hf_base.py
@@ -38,9 +38,6 @@
 
 __all__ = ['BaseHuggingFaceModel']
 
-# HuggingFace hardcodes the ignore index to -100
-_HF_IGNORE_INDEX = -100
-
 log = logging.getLogger(__name__)
 
 
3 changes: 2 additions & 1 deletion llmfoundry/models/inference_api_wrapper/interface.py
@@ -12,6 +12,7 @@
 
 from llmfoundry.eval.metrics import InContextLearningMetric
 from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 
 __all__ = ['InferenceAPIEvalWrapper']
 
@@ -92,7 +93,7 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
         batch = self.rebatch(batch)
         self.labels = batch.pop('labels')
         self.labels[:, :-1] = self.labels[:, 1:].clone()
-        self.labels[:, -1] = -100
+        self.labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
         if isinstance(
             metric,
             InContextLearningMetric,
10 changes: 6 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -81,6 +81,8 @@
 
 log = logging.getLogger(__name__)
 
+CROSS_ENTROPY_IGNORE_INDEX = -100
+
 
 class InvalidConfigAccessError(KeyError):
     pass
@@ -1181,7 +1183,7 @@ def forward(
         loss = None
         if labels is not None:
             _labels = torch.roll(labels, shifts=-1)
-            _labels[:, -1] = -100
+            _labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)),
                 _labels.to(logits.device).view(-1),
@@ -1331,7 +1333,7 @@ def _reorder_cache(
 
 def get_targets(labels: torch.Tensor) -> torch.Tensor:
     targets = torch.roll(labels, shifts=-1)
-    targets[:, -1] = -100
+    targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
     return targets
 
 
@@ -1410,7 +1412,7 @@ def __init__(
                 CrossEntropyLoss as FusedCrossEntropyLoss
 
             self.loss_fn = FusedCrossEntropyLoss(
-                ignore_index=-100,
+                ignore_index=CROSS_ENTROPY_IGNORE_INDEX,
                 reduction='none',
             )
         except:
@@ -1423,7 +1425,7 @@ def __init__(
             )
         elif loss_fn_config == 'torch_crossentropy':
             self.loss_fn = nn.CrossEntropyLoss(
-                ignore_index=-100,
+                ignore_index=CROSS_ENTROPY_IGNORE_INDEX,
                 reduction='none',
             )
         else:
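
Aside: a quick check of get_targets as shown above. torch.roll(labels, shifts=-1) moves every label one position left (position t is trained to predict token t+1), and the wrapped-around final slot is masked out.

    import torch

    CROSS_ENTROPY_IGNORE_INDEX = -100

    def get_targets(labels: torch.Tensor) -> torch.Tensor:
        targets = torch.roll(labels, shifts=-1)
        targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
        return targets

    labels = torch.tensor([[10, 11, 12, 13]])
    print(get_targets(labels))  # tensor([[  11,   12,   13, -100]])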
2 changes: 2 additions & 0 deletions llmfoundry/utils/__init__.py
@@ -29,6 +29,7 @@
     process_init_device,
     update_batch_size_info,
 )
+from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
 from llmfoundry.utils.data_prep_utils import (
     DownloadingIterable,
     merge_shard_groups,
@@ -111,4 +112,5 @@
     'ExperimentalWarning',
     'experimental_function',
     'experimental_class',
+    'CROSS_ENTROPY_IGNORE_INDEX',
 ]
4 changes: 4 additions & 0 deletions llmfoundry/utils/consts.py
@@ -0,0 +1,4 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+CROSS_ENTROPY_IGNORE_INDEX = -100
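
Aside: -100 matches the default ignore_index of PyTorch's cross entropy (the convention Hugging Face also relies on), so labels set to this value contribute nothing to the loss. A minimal demonstration:

    import torch
    import torch.nn.functional as F

    CROSS_ENTROPY_IGNORE_INDEX = -100

    logits = torch.randn(4, 10)
    labels = torch.tensor([3, CROSS_ENTROPY_IGNORE_INDEX, 7, CROSS_ENTROPY_IGNORE_INDEX])

    # Only positions 0 and 2 contribute; the mean is taken over 2 tokens, not 4.
    loss = F.cross_entropy(logits, labels, ignore_index=CROSS_ENTROPY_IGNORE_INDEX)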