From 874c30aadadb9a1c477a476a3ea2784f5d235eea Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sat, 26 Oct 2024 18:56:40 -0700 Subject: [PATCH] Add loss generating token counts (#1610) --- .../callbacks/loss_perp_v_len_callback.py | 5 ++-- llmfoundry/data/finetuning/collator.py | 24 ++++++++++--------- llmfoundry/data/finetuning/dataloader.py | 16 ++++++------- llmfoundry/data/finetuning/tasks.py | 5 ++-- llmfoundry/data/packing.py | 6 +++-- llmfoundry/data/utils.py | 24 +++++++++++++++++-- llmfoundry/models/hf/hf_base.py | 3 --- .../models/inference_api_wrapper/interface.py | 3 ++- llmfoundry/models/mpt/modeling_mpt.py | 10 ++++---- llmfoundry/utils/__init__.py | 2 ++ llmfoundry/utils/consts.py | 4 ++++ .../test_loss_perp_v_len_callback.py | 17 ++++++++----- tests/data/test_dataloader.py | 23 ++++++++++++++---- tests/models/test_model.py | 14 +++++++---- 14 files changed, 106 insertions(+), 50 deletions(-) create mode 100644 llmfoundry/utils/consts.py diff --git a/llmfoundry/callbacks/loss_perp_v_len_callback.py b/llmfoundry/callbacks/loss_perp_v_len_callback.py index b402972198..31fc20c775 100644 --- a/llmfoundry/callbacks/loss_perp_v_len_callback.py +++ b/llmfoundry/callbacks/loss_perp_v_len_callback.py @@ -10,6 +10,7 @@ from torchmetrics import Metric from llmfoundry.models.mpt import ComposerMPTCausalLM +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX from llmfoundry.utils.warnings import experimental_class __all__ = [ @@ -33,7 +34,7 @@ def __init__( self, log_batch_interval: int, compute_batch_interval: int, - ignore_index: int = -100, + ignore_index: int = CROSS_ENTROPY_IGNORE_INDEX, ): if compute_batch_interval > log_batch_interval: raise ValueError( @@ -69,7 +70,7 @@ def after_backward(self, state: State, logger: Logger) -> None: labels = state.batch['labels'] if state.model.shift_labels: labels[:, :-1] = labels[:, 1:].detach().clone() - labels[:, -1] = -100 + labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX seq_parallel_world_size = getattr( state.model.model.transformer, 'seq_parallel_world_size', diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py index b24afd163e..4d9284682e 100644 --- a/llmfoundry/data/finetuning/collator.py +++ b/llmfoundry/data/finetuning/collator.py @@ -8,15 +8,14 @@ import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX + log = logging.getLogger(__name__) __all__ = [ 'Seq2SeqFinetuningCollator', ] -# HuggingFace hardcodes the ignore index to -100 -_HF_IGNORE_INDEX = -100 - TokenizedExample = dict[str, list[dict[str, list[int]]]] @@ -79,7 +78,7 @@ def _sequence_to_labels_none( cutoff: Optional[int] = None, ) -> list[int]: del is_last_turn, cutoff # unused - return [_HF_IGNORE_INDEX] * len(sequence) + return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence) def _sequence_to_labels_last( @@ -91,7 +90,7 @@ def _sequence_to_labels_last( if is_last_turn: return sequence else: - return [_HF_IGNORE_INDEX] * len(sequence) + return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence) def _sequence_to_labels_cutoff( @@ -105,7 +104,7 @@ def _sequence_to_labels_cutoff( if len(sequence) >= cutoff: return sequence else: - return [_HF_IGNORE_INDEX] * len(sequence) + return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence) _TARGET_POLICY_LOOKUP = { @@ -352,7 +351,8 @@ def _process_and_batch_decoder_only( labels = labels[:max_seq_len] # Check to make sure there are still loss-generating tokens. Error if not. - if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0: + if len([l for l in labels if l != CROSS_ENTROPY_IGNORE_INDEX + ],) == 0: raise ValueError( f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\ f'Pre-truncation sequence length was {orig_size}. ' +\ @@ -375,7 +375,7 @@ def _process_and_batch_decoder_only( # Annoyingly, we need to pad everything but input_ids # and attention_mask ourselves n_total = len(input_ids) - i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total) + i_pad = [CROSS_ENTROPY_IGNORE_INDEX] * (max_seq_len - n_total) if self.tokenizer.padding_side == 'left': labels = i_pad + labels else: @@ -444,7 +444,9 @@ def _process_and_batch_encoder_decoder( for context, target in contexts_and_targets: # We need to pad labels ourselves. Because HF. if len(target) < max_seq_len: - i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target)) + i_pad = [ + CROSS_ENTROPY_IGNORE_INDEX, + ] * (max_seq_len - len(target)) target = target + i_pad else: if not self._warned_target: @@ -491,12 +493,12 @@ def _process_and_batch_encoder_decoder( ], dim=1) batch['decoder_input_ids'].masked_fill_( - batch['decoder_input_ids'] == _HF_IGNORE_INDEX, + batch['decoder_input_ids'] == CROSS_ENTROPY_IGNORE_INDEX, self.tokenizer.pad_token_id, ) batch['decoder_attention_mask'] = torch.not_equal( batch['labels'], - _HF_IGNORE_INDEX, + CROSS_ENTROPY_IGNORE_INDEX, ) # This logic prevents trimming on at least the first batch diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index e73213f74a..661729ff8a 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -26,6 +26,7 @@ from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import build_streams from llmfoundry.utils.config_utils import to_dict_container +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX from llmfoundry.utils.exceptions import ( FinetuningFileNotFoundError, MissingHuggingFaceURLSplitError, @@ -40,9 +41,6 @@ 'build_finetuning_dataloader', ] -# HuggingFace hardcodes the ignore index to -100 -_HF_IGNORE_INDEX = -100 - # Extra keys present in the dataset config dictionary beyond the constructor keys _ALLOWED_DATASET_KEYS = { 'shuffle', @@ -786,7 +784,7 @@ def build_collate_fn( ) context = torch.logical_and( batch['attention_mask'][j] == 1, - batch['labels'][j] == _HF_IGNORE_INDEX, + batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX, ) print( '\033[92m{}\033[00m\n'.format('CONTEXT: '), @@ -804,7 +802,8 @@ def build_collate_fn( j, torch.logical_and( is_subseq, - batch['labels'][j] != _HF_IGNORE_INDEX, + batch['labels'][j] != + CROSS_ENTROPY_IGNORE_INDEX, )], skip_special_tokens=False, clean_up_tokenization_spaces=True, @@ -822,7 +821,7 @@ def build_collate_fn( ) context = torch.logical_and( batch['attention_mask'][j] == 1, - batch['labels'][j] == _HF_IGNORE_INDEX, + batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX, ) print( '\033[92m{}\033[00m\n'.format('CONTEXT: '), @@ -835,8 +834,9 @@ def build_collate_fn( print( '\033[91m{}\033[00m\n'.format('TARGET: '), tokenizer.decode( - batch['input_ids'][ - j, batch['labels'][j] != _HF_IGNORE_INDEX], + batch['input_ids'] + [j, + batch['labels'][j] != CROSS_ENTROPY_IGNORE_INDEX], skip_special_tokens=False, clean_up_tokenization_spaces=True, ), diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 2030a87fe7..b83aee3aa6 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -62,11 +62,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: stream_remote_local_validate, ) from llmfoundry.data.finetuning.collator import ( - _HF_IGNORE_INDEX, stitch_turns_decoder_only, stitch_turns_encoder_decoder, ) from llmfoundry.tokenizers import get_date_string +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX # yapf: disable from llmfoundry.utils.exceptions import ( ALLOWED_MESSAGES_KEYS, @@ -501,7 +501,8 @@ def is_valid_ift_example( if len(input_ids) == 0: return False - if len([label for label in labels if label != _HF_IGNORE_INDEX]) == 0: + if len([label for label in labels if label != CROSS_ENTROPY_IGNORE_INDEX + ],) == 0: return False return True diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 2f18bda7fc..e3c19cc91c 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -10,6 +10,8 @@ from composer.utils import dist from transformers import PreTrainedTokenizerBase +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX + log = logging.getLogger(__name__) __all__ = [ @@ -152,7 +154,7 @@ def _convert_to_batch( pad_vals = { 'input_ids': self.pad_token_id, - 'labels': -100, + 'labels': CROSS_ENTROPY_IGNORE_INDEX, 'attention_mask': 0, 'sequence_id': -1, } @@ -317,7 +319,7 @@ def _combine_in_place( if 'labels' in add_on: # Prevents the last token in example from being trained to # predict the first token in add_on, which would make no sense. - add_on['labels'][0] = -100 + add_on['labels'][0] = CROSS_ENTROPY_IGNORE_INDEX for k in example.keys(): if k == 'sequence_id': diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index e8ed1a947d..21c28d9183 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -15,6 +15,7 @@ from llmfoundry.data.finetuning.dataloader import build_collate_fn from llmfoundry.data.packing import BinPackCollator from llmfoundry.data.text_data import ConcatenatedSequenceCollatorWrapper +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX log = logging.getLogger(__name__) @@ -83,7 +84,7 @@ def get_data_spec( def get_tokens_per_batch_func( decoder_only: bool = True, -) -> Callable[[Batch], int]: +) -> Callable[[Batch], Union[int, dict[str, int]]]: """Returns a callable that counts the number of tokens in a batch. Args: @@ -95,7 +96,7 @@ def get_tokens_per_batch_func( Callable[[Batch], int]: A callable that counts the number of tokens in a batch. """ - def get_num_tokens_in_batch(batch: Batch) -> int: + def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: if not isinstance(batch, Mapping) or ( 'attention_mask' not in batch and 'input_ids' not in batch ): @@ -114,6 +115,20 @@ def get_num_tokens_in_batch(batch: Batch) -> int: else: input_ids_tokens = batch['input_ids'].numel() + loss_generating_tokens = None + if 'labels' in batch: + loss_generating_tokens = int( + torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(), + ) + + # Subtract one for each example in the batch that starts with a non -100, + # because those will be shifted off + loss_generating_tokens -= int( + torch.sum( + batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX, + ).item(), + ) + # For encoder decoder models only decoder_input_ids_tokens = 0 if not decoder_only: @@ -121,6 +136,11 @@ def get_num_tokens_in_batch(batch: Batch) -> int: torch.sum(batch['decoder_attention_mask']).item(), ) + if loss_generating_tokens is not None: + return { + 'total': input_ids_tokens + decoder_input_ids_tokens, + 'loss_generating': loss_generating_tokens, + } return input_ids_tokens + decoder_input_ids_tokens return get_num_tokens_in_batch diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py index 2ec9bbaa98..10b49d1ec8 100644 --- a/llmfoundry/models/hf/hf_base.py +++ b/llmfoundry/models/hf/hf_base.py @@ -38,9 +38,6 @@ __all__ = ['BaseHuggingFaceModel'] -# HuggingFace hardcodes the ignore index to -100 -_HF_IGNORE_INDEX = -100 - log = logging.getLogger(__name__) diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index 6d231441ae..ca541d536d 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -12,6 +12,7 @@ from llmfoundry.eval.metrics import InContextLearningMetric from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX __all__ = ['InferenceAPIEvalWrapper'] @@ -92,7 +93,7 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None: batch = self.rebatch(batch) self.labels = batch.pop('labels') self.labels[:, :-1] = self.labels[:, 1:].clone() - self.labels[:, -1] = -100 + self.labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX if isinstance( metric, InContextLearningMetric, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9212f5594d..da576b29e1 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -81,6 +81,8 @@ log = logging.getLogger(__name__) +CROSS_ENTROPY_IGNORE_INDEX = -100 + class InvalidConfigAccessError(KeyError): pass @@ -1181,7 +1183,7 @@ def forward( loss = None if labels is not None: _labels = torch.roll(labels, shifts=-1) - _labels[:, -1] = -100 + _labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX loss = F.cross_entropy( logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1), @@ -1331,7 +1333,7 @@ def _reorder_cache( def get_targets(labels: torch.Tensor) -> torch.Tensor: targets = torch.roll(labels, shifts=-1) - targets[:, -1] = -100 + targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX return targets @@ -1410,7 +1412,7 @@ def __init__( CrossEntropyLoss as FusedCrossEntropyLoss self.loss_fn = FusedCrossEntropyLoss( - ignore_index=-100, + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, reduction='none', ) except: @@ -1423,7 +1425,7 @@ def __init__( ) elif loss_fn_config == 'torch_crossentropy': self.loss_fn = nn.CrossEntropyLoss( - ignore_index=-100, + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, reduction='none', ) else: diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index 87a08a999d..683d719c90 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -29,6 +29,7 @@ process_init_device, update_batch_size_info, ) +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX from llmfoundry.utils.data_prep_utils import ( DownloadingIterable, merge_shard_groups, @@ -111,4 +112,5 @@ 'ExperimentalWarning', 'experimental_function', 'experimental_class', + 'CROSS_ENTROPY_IGNORE_INDEX', ] diff --git a/llmfoundry/utils/consts.py b/llmfoundry/utils/consts.py new file mode 100644 index 0000000000..2cc7a1a175 --- /dev/null +++ b/llmfoundry/utils/consts.py @@ -0,0 +1,4 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +CROSS_ENTROPY_IGNORE_INDEX = -100 diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py index 4c487560d2..7ab1a7c5ba 100644 --- a/tests/callbacks/test_loss_perp_v_len_callback.py +++ b/tests/callbacks/test_loss_perp_v_len_callback.py @@ -21,6 +21,7 @@ build_text_dataloader, ) from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX from llmfoundry.utils.registry_utils import construct_from_registry @@ -201,7 +202,7 @@ def mock_loss_fn(input_logits: Any, input_labels: Any): del input_logits, input_labels return loss - loss_v_len_metric = LossPerpVLen(ignore_index=-100) + loss_v_len_metric = LossPerpVLen(ignore_index=CROSS_ENTROPY_IGNORE_INDEX) loss_v_len_metric.update( labels=labels, logits=logits, @@ -288,7 +289,7 @@ def test_valid_labels(): labels = torch.tensor([[ 1, ] * (seq_len - ignore_labels_len) + [ - -100, + CROSS_ENTROPY_IGNORE_INDEX, ] * ignore_labels_len] * batch_size) logits = torch.tensor([[ 1, @@ -302,7 +303,7 @@ def mock_loss_fn(input_logits: Any, input_labels: Any): del input_logits, input_labels return loss - loss_v_len_metric = LossPerpVLen(ignore_index=-100) + loss_v_len_metric = LossPerpVLen(ignore_index=CROSS_ENTROPY_IGNORE_INDEX) loss_v_len_metric.update( labels=labels, logits=logits, @@ -349,7 +350,9 @@ def mock_loss_fn_no_pad(input_logits: Any, input_labels: Any): del input_logits, input_labels return loss_no_pad - loss_v_len_metric_no_pad = LossPerpVLen(ignore_index=-100) + loss_v_len_metric_no_pad = LossPerpVLen( + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, + ) loss_v_len_metric_no_pad.update( labels=labels_no_pad, logits=logits_no_pad, @@ -362,7 +365,7 @@ def mock_loss_fn_no_pad(input_logits: Any, input_labels: Any): labels_pad = torch.tensor([[ 1, ] * seq_len + [ - -100, + CROSS_ENTROPY_IGNORE_INDEX, ] * pad_len] * batch_size) logits_pad = torch.tensor([[ 1, @@ -388,7 +391,9 @@ def mock_loss_fn_pad(input_logits: Any, input_labels: Any): del input_logits, input_labels return loss_pad - loss_v_len_metric_pad = LossPerpVLen(ignore_index=-100) + loss_v_len_metric_pad = LossPerpVLen( + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, + ) loss_v_len_metric_pad.update( labels=labels_pad, logits=logits_pad, diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 2608ccd091..682d25f7f1 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -5,6 +5,7 @@ import pathlib import random import shutil +from collections import Counter from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import Any, Callable, ContextManager, Literal, Optional, Union @@ -26,7 +27,6 @@ get_columns_and_format from llmfoundry.data import build_dataloader, build_finetuning_dataloader from llmfoundry.data.finetuning.collator import ( - _HF_IGNORE_INDEX, validate_target_settings, ) from llmfoundry.data.finetuning.tasks import ( @@ -43,6 +43,7 @@ from llmfoundry.data.utils import get_tokens_per_batch_func from llmfoundry.utils.builders import build_tokenizer from llmfoundry.utils.config_utils import to_dict_container +from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX # yapf: disable from llmfoundry.utils.exceptions import ( ConsecutiveRepeatedChatRolesError, @@ -269,7 +270,7 @@ def test_correct_padding( torch.ones_like(batch['input_ids'], dtype=torch.bool), ) a = attention_mask == 0 - b = batch['labels'] == -100 + b = batch['labels'] == CROSS_ENTROPY_IGNORE_INDEX assert torch.equal(a, b) @@ -1104,7 +1105,7 @@ def pad_preprocessing_function( # type: ignore labels = batch['labels'][ 0, torch. - logical_and(is_subseq, batch['labels'][0] != _HF_IGNORE_INDEX)] + logical_and(is_subseq, batch['labels'][0] != CROSS_ENTROPY_IGNORE_INDEX)] assert all(labels[:-1] == tokenizer.pad_token_id) if i >= 20: break @@ -1185,12 +1186,15 @@ def test_token_counting_func_dataloader_setting( batch_strings = [] expected_token_count = 0 + expected_loss_generating_token_count = 0 + sample_lengths = [] for _ in range(batch_size): # Get randomly different lengths if we are going to add padding sample_length = random.randint(1, model_max_length // 4) if ( pad_token_id is not None and not tensor_input ) else model_max_length // 4 batch_strings.append(' '.join(['hello'] * sample_length)) + sample_lengths.append(sample_length) expected_token_count += sample_length batch_tokenized = [ @@ -1207,8 +1211,15 @@ def test_token_counting_func_dataloader_setting( for b in batch_tokenized: b['labels'] = b['input_ids'].copy() # type: ignore batch_tokenized = [{'turns': [b]} for b in batch_tokenized] + expected_loss_generating_token_count = expected_token_count expected_token_count *= 2 expected_token_count += 1 * batch_size # for the eos token + expected_loss_generating_token_count += 1 * batch_size # for the eos token + else: + expected_loss_generating_token_count = expected_token_count + + number_of_shifted_off_labels = Counter(sample_lengths)[max(sample_lengths)] + expected_loss_generating_token_count -= 1 * number_of_shifted_off_labels # because the labels will be shifted common_args = { 'drop_last': False, @@ -1310,9 +1321,11 @@ def build_from_hf( raise NotImplementedError() batch_collated = dl.dataloader.collate_fn(batch_tokenized) # type: ignore - actual_token_count = dl.get_num_tokens_in_batch(batch_collated) + actual_total_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='total') + actual_loss_generating_token_count = dl.get_num_tokens_in_batch(batch_collated, token_type='loss_generating') - assert actual_token_count == expected_token_count + assert actual_total_token_count == expected_token_count + assert actual_loss_generating_token_count == expected_loss_generating_token_count def test_build_unknown_dataloader(): diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 43067f5e47..a7769c237d 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -49,7 +49,10 @@ ) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel -from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry +from llmfoundry.models.mpt.modeling_mpt import ( + CROSS_ENTROPY_IGNORE_INDEX, + LlamaRotaryEmbeddingFoundry, +) from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -625,7 +628,10 @@ def test_loss_fn(): model_2.to(test_cfg.device) assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss) - model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100, reduction='none') + model_2.loss_fn = FusedCrossEntropyLoss( + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, + reduction='none', + ) optimizer_1 = DecoupledAdamW( model_1.parameters(), @@ -732,13 +738,13 @@ def test_loss_reduction(loss_fn_config: str): if loss_fn_config == 'fused_crossentropy': assert isinstance(model_1.loss_fn, FusedCrossEntropyLoss) model_2.loss_fn = FusedCrossEntropyLoss( - ignore_index=-100, + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, reduction='mean', ) else: assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss) model_2.loss_fn = torch.nn.CrossEntropyLoss( - ignore_index=-100, + ignore_index=CROSS_ENTROPY_IGNORE_INDEX, reduction='mean', )