diff --git a/llmfoundry/metrics/__init__.py b/llmfoundry/metrics/__init__.py new file mode 100644 index 0000000000..db4beba80e --- /dev/null +++ b/llmfoundry/metrics/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.metrics.token_acc import TokenAccuracy + +__all__ = [ + 'TokenAccuracy', +] diff --git a/llmfoundry/metrics/token_acc.py b/llmfoundry/metrics/token_acc.py new file mode 100644 index 0000000000..1cdcffe1db --- /dev/null +++ b/llmfoundry/metrics/token_acc.py @@ -0,0 +1,65 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torchmetrics import Metric + +__all__ = [ + 'TokenAccuracy', +] + + +class TokenAccuracy(Metric): + """Torchmetric to compute token-level accuracy for language modeling. + + Adds metric state variables: + correct_tokens (float): The number of correct token predictions. + total_tokens (float): The total number of tokens predicted. + + Args: + ignore_index (int, optional): The index of tokens to ignore, typically for padding. Default: -100. + dist_sync_on_step (bool, optional): Synchronize metric state across processes at + each forward() before returning the value at the step. Default: False. + """ + + # Ensures torchmetrics calls update only once + full_state_update = False + + def __init__(self, + ignore_index: int = -100, + dist_sync_on_step: bool = False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.ignore_index = ignore_index + self.add_state('correct_tokens', + default=torch.tensor(0), + dist_reduce_fx='sum') + self.add_state('total_tokens', + default=torch.tensor(0), + dist_reduce_fx='sum') + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """Updates the internal state with results from a new batch. + + Args: + preds (~torch.Tensor): The predictions from the model, a Tensor of logits. + target (~torch.Tensor): A Tensor of ground-truth token values. + """ + # Convert logits to predicted token indices + preds = torch.argmax(preds, dim=-1) + + # Create mask for non-ignored tokens + mask = (target != self.ignore_index) + masked_target = target[mask] + masked_preds = preds[mask] + + # Update correct and total counts + self.correct_tokens += torch.sum(masked_preds == masked_target) + self.total_tokens += masked_target.numel() + + def compute(self) -> torch.Tensor: + """Aggregate the state over all processes to compute the metric. + + Returns: + The mean accuracy across all tokens as a :class:`~torch.Tensor`. + """ + return self.correct_tokens.float() / self.total_tokens diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 9f1136e597..dd766c99af 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -22,6 +22,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase) +from llmfoundry.metrics import TokenAccuracy from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -110,10 +111,15 @@ def __init__(self, om_model_config: DictConfig, ) # Set up training and eval metrics - train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()] + train_metrics = [ + LanguageCrossEntropy(), + LanguagePerplexity(), + TokenAccuracy() + ] eval_metrics = [ LanguageCrossEntropy(), LanguagePerplexity(), + TokenAccuracy(), InContextLearningLMAccuracy(), InContextLearningMultipleChoiceAccuracy(), InContextLearningQAAccuracy(), diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index e9ad8054e2..93d8cbef74 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -26,6 +26,7 @@ from composer.models import HuggingFaceModel from composer.utils import dist +from llmfoundry.metrics import TokenAccuracy from llmfoundry.models.layers.attention import (is_flash_v1_installed, is_flash_v2_installed) from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY @@ -1045,11 +1046,15 @@ def __init__( model = MPTForCausalLM(hf_config) use_train_metrics = om_model_config.get('use_train_metrics', True) - train_metrics = [LanguageCrossEntropy(), - LanguagePerplexity()] if use_train_metrics else [] + train_metrics = [ + LanguageCrossEntropy(), + LanguagePerplexity(), + TokenAccuracy() + ] if use_train_metrics else [] eval_metrics = [ LanguageCrossEntropy(), LanguagePerplexity(), + TokenAccuracy(), InContextLearningLMAccuracy(), InContextLearningMultipleChoiceAccuracy(), InContextLearningQAAccuracy(), diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 07a9c3900e..9fdc20c0d6 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -54,10 +54,26 @@ def _flatten_import( return False +def _remove_import( + node: ast.ImportFrom, + remove_imports_prefix: Sequence[str], +) -> bool: + """Returns True if import should be removed. + + Checks whether the node starts the same as any of the imports in + remove_imports_prefix. + """ + for import_prefix in remove_imports_prefix: + if node.module is not None and node.module.startswith(import_prefix): + return True + return False + + def process_file( file_path: str, folder_path: str, flatten_imports_prefix: Sequence[str], + remove_imports_prefix: Sequence[str], ) -> list[str]: with open(file_path, 'r', encoding='utf-8') as f: source = f.read() @@ -70,19 +86,21 @@ def process_file( new_files_to_process = [] nodes_to_remove = [] for node in ast.walk(tree): - # Convert any llmfoundry imports into relative imports - if (isinstance(node, ast.ImportFrom) and node.module is not None and - _flatten_import(node, flatten_imports_prefix)): + # Remove any imports matching the remove_imports_prefix + if isinstance( + node, + ast.ImportFrom) and node.module is not None and _remove_import( + node, remove_imports_prefix): + nodes_to_remove.append(node) + # Convert any (remaining) imports matching the flatten_imports_prefix + # to relative imports + elif (isinstance(node, ast.ImportFrom) and node.module is not None and + _flatten_import(node, flatten_imports_prefix)): module_path = find_module_file(node.module) node.module = convert_to_relative_import(node.module, parent_module_name) # Recursively process any llmfoundry files new_files_to_process.append(module_path) - # Remove any imports from composer or omegaconf - elif isinstance(node, ast.ImportFrom) and node.module is not None and ( - node.module.startswith('composer') or - node.module.startswith('omegaconf')): - nodes_to_remove.append(node) # Remove the Composer* class elif (isinstance(node, ast.ClassDef) and node.name.startswith('Composer')): @@ -110,9 +128,19 @@ def process_file( def edit_files_for_hf_compatibility( - folder: str, - flatten_imports_prefix: Sequence[str] = ('llmfoundry',), + folder: str, + flatten_imports_prefix: Sequence[str] = ('llmfoundry',), + remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf', + 'llmfoundry.metrics'), ) -> None: + """Edit files to be compatible with Hugging Face Hub. + + Args: + folder (str): The folder to process. + flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',). + remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening. + Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics'). + """ files_to_process = [ os.path.join(folder, filename) for filename in os.listdir(folder) @@ -123,7 +151,12 @@ def edit_files_for_hf_compatibility( while len(files_to_process) > 0: to_process = files_to_process.pop() if os.path.isfile(to_process) and to_process.endswith('.py'): - to_add = process_file(to_process, folder, flatten_imports_prefix) + to_add = process_file( + to_process, + folder, + flatten_imports_prefix, + remove_imports_prefix, + ) for file in to_add: if file not in files_processed_and_queued: files_to_process.append(file) diff --git a/tests/metrics/test_token_acc.py b/tests/metrics/test_token_acc.py new file mode 100644 index 0000000000..30fd5c07f6 --- /dev/null +++ b/tests/metrics/test_token_acc.py @@ -0,0 +1,32 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from llmfoundry.metrics import TokenAccuracy + + +@pytest.mark.parametrize('ignore_index', [-100, -200]) +@pytest.mark.parametrize('vocab_size', [100]) +def test_token_accuracy(ignore_index: int, vocab_size: int): + batch_size = int(1e6) + torchmetrics_token_acc = TokenAccuracy(ignore_index=ignore_index) + generated_preds = torch.rand((batch_size, vocab_size)) + true_labels = torch.randint(low=0, high=vocab_size - 1, size=(batch_size,)) + + # Randomly insert ignore_index into the labels + labels_mask = torch.rand((batch_size,)) + labels_mask[labels_mask > 0.8] = 1 + labels_mask[labels_mask <= 0.8] = 0 + labels_mask = labels_mask.bool() + true_labels[labels_mask] = ignore_index + + true_labels = true_labels.float() + generated_preds = generated_preds.float() + + torchmetrics_token_acc.update(generated_preds, true_labels) + final_acc = torchmetrics_token_acc.compute() + + expected_random_acc_tensor = torch.tensor(1.0 / vocab_size) + torch.testing.assert_close(final_acc, expected_random_acc_tensor) diff --git a/tests/utils/test_huggingface_hub_utils.py b/tests/utils/test_huggingface_hub_utils.py index 5effb3a771..39dbf2781d 100644 --- a/tests/utils/test_huggingface_hub_utils.py +++ b/tests/utils/test_huggingface_hub_utils.py @@ -3,7 +3,10 @@ import ast -from llmfoundry.utils.huggingface_hub_utils import _flatten_import +import pytest + +from llmfoundry.utils.huggingface_hub_utils import (_flatten_import, + _remove_import) def test_flatten_import_true(): @@ -14,3 +17,24 @@ def test_flatten_import_true(): def test_flatten_import_false(): node = ast.ImportFrom('y', ['x', 'y', 'z']) assert not _flatten_import(node, ('x', 'z')) + + +@pytest.mark.parametrize('prefix_to_remove,expected_imports_remaining', + [('llmfoundry', 1), ('llmfoundry.utils', 2)]) +def test_remove_imports(prefix_to_remove: str, expected_imports_remaining: int): + source_code = """ +from llmfoundry import a +from llmfoundry.utils import b +from other_package import c +""" + + tree = ast.parse(source_code) + assert len(tree.body) == 3 + + imports_kept = 0 + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and not _remove_import( + node, [prefix_to_remove]): + imports_kept += 1 + + assert imports_kept == expected_imports_remaining