Token accuracy metrics (#983)
dakinggg authored Feb 21, 2024
1 parent 6e3842b commit 2431730
Showing 7 changed files with 188 additions and 15 deletions.
8 changes: 8 additions & 0 deletions llmfoundry/metrics/__init__.py
@@ -0,0 +1,8 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.metrics.token_acc import TokenAccuracy

__all__ = [
    'TokenAccuracy',
]
65 changes: 65 additions & 0 deletions llmfoundry/metrics/token_acc.py
@@ -0,0 +1,65 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import torch
from torchmetrics import Metric

__all__ = [
    'TokenAccuracy',
]


class TokenAccuracy(Metric):
    """Torchmetric to compute token-level accuracy for language modeling.

    Adds metric state variables:
        correct_tokens (float): The number of correct token predictions.
        total_tokens (float): The total number of tokens predicted.

    Args:
        ignore_index (int, optional): The index of tokens to ignore, typically for padding. Default: -100.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at
            each forward() before returning the value at the step. Default: False.
    """

    # Ensures torchmetrics calls update only once
    full_state_update = False

    def __init__(self,
                 ignore_index: int = -100,
                 dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_index = ignore_index
        self.add_state('correct_tokens',
                       default=torch.tensor(0),
                       dist_reduce_fx='sum')
        self.add_state('total_tokens',
                       default=torch.tensor(0),
                       dist_reduce_fx='sum')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Updates the internal state with results from a new batch.

        Args:
            preds (~torch.Tensor): The predictions from the model, a Tensor of logits.
            target (~torch.Tensor): A Tensor of ground-truth token values.
        """
        # Convert logits to predicted token indices
        preds = torch.argmax(preds, dim=-1)

        # Create mask for non-ignored tokens
        mask = (target != self.ignore_index)
        masked_target = target[mask]
        masked_preds = preds[mask]

        # Update correct and total counts
        self.correct_tokens += torch.sum(masked_preds == masked_target)
        self.total_tokens += masked_target.numel()

    def compute(self) -> torch.Tensor:
        """Aggregate the state over all processes to compute the metric.

        Returns:
            The mean accuracy across all tokens as a :class:`~torch.Tensor`.
        """
        return self.correct_tokens.float() / self.total_tokens
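For context, a minimal usage sketch of the new metric. The tensor values below are illustrative only and are not part of the commit: a batch of four token positions over a three-token vocabulary, with the last position padded.

import torch

from llmfoundry.metrics import TokenAccuracy

metric = TokenAccuracy(ignore_index=-100)
logits = torch.tensor([[2.0, 0.1, 0.1],
                       [0.1, 2.0, 0.1],
                       [0.1, 0.1, 2.0],
                       [2.0, 0.1, 0.1]])
labels = torch.tensor([0, 1, 1, -100])  # last position is padding and is ignored
metric.update(logits, labels)
print(metric.compute())  # tensor(0.6667): 2 of the 3 non-ignored tokens are correct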
8 changes: 7 additions & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -22,6 +22,7 @@
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                          PreTrainedTokenizerBase)

from llmfoundry.metrics import TokenAccuracy
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.attention import is_flash_v2_installed
@@ -110,10 +111,15 @@ def __init__(self, om_model_config: DictConfig,
)

        # Set up training and eval metrics
        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
        train_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            TokenAccuracy()
        ]
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            TokenAccuracy(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
9 changes: 7 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -26,6 +26,7 @@
from composer.models import HuggingFaceModel
from composer.utils import dist

from llmfoundry.metrics import TokenAccuracy
from llmfoundry.models.layers.attention import (is_flash_v1_installed,
                                                 is_flash_v2_installed)
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
@@ -1045,11 +1046,15 @@ def __init__(
        model = MPTForCausalLM(hf_config)

        use_train_metrics = om_model_config.get('use_train_metrics', True)
        train_metrics = [LanguageCrossEntropy(),
                         LanguagePerplexity()] if use_train_metrics else []
        train_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            TokenAccuracy()
        ] if use_train_metrics else []
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            TokenAccuracy(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
55 changes: 44 additions & 11 deletions llmfoundry/utils/huggingface_hub_utils.py
@@ -54,10 +54,26 @@ def _flatten_import(
    return False


def _remove_import(
    node: ast.ImportFrom,
    remove_imports_prefix: Sequence[str],
) -> bool:
    """Returns True if import should be removed.

    Checks whether the node starts the same as any of the imports in
    remove_imports_prefix.
    """
    for import_prefix in remove_imports_prefix:
        if node.module is not None and node.module.startswith(import_prefix):
            return True
    return False


def process_file(
    file_path: str,
    folder_path: str,
    flatten_imports_prefix: Sequence[str],
    remove_imports_prefix: Sequence[str],
) -> list[str]:
    with open(file_path, 'r', encoding='utf-8') as f:
        source = f.read()
@@ -70,19 +86,21 @@
    new_files_to_process = []
    nodes_to_remove = []
    for node in ast.walk(tree):
        # Convert any llmfoundry imports into relative imports
        if (isinstance(node, ast.ImportFrom) and node.module is not None and
                _flatten_import(node, flatten_imports_prefix)):
        # Remove any imports matching the remove_imports_prefix
        if isinstance(
                node,
                ast.ImportFrom) and node.module is not None and _remove_import(
                    node, remove_imports_prefix):
            nodes_to_remove.append(node)
        # Convert any (remaining) imports matching the flatten_imports_prefix
        # to relative imports
        elif (isinstance(node, ast.ImportFrom) and node.module is not None and
              _flatten_import(node, flatten_imports_prefix)):
            module_path = find_module_file(node.module)
            node.module = convert_to_relative_import(node.module,
                                                     parent_module_name)
            # Recursively process any llmfoundry files
            new_files_to_process.append(module_path)
        # Remove any imports from composer or omegaconf
        elif isinstance(node, ast.ImportFrom) and node.module is not None and (
                node.module.startswith('composer') or
                node.module.startswith('omegaconf')):
            nodes_to_remove.append(node)
        # Remove the Composer* class
        elif (isinstance(node, ast.ClassDef) and
              node.name.startswith('Composer')):
@@ -110,9 +128,19 @@


def edit_files_for_hf_compatibility(
        folder: str,
        flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
    folder: str,
    flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
    remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf',
                                            'llmfoundry.metrics'),
) -> None:
    """Edit files to be compatible with Hugging Face Hub.

    Args:
        folder (str): The folder to process.
        flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
        remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
            Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
    """
    files_to_process = [
        os.path.join(folder, filename)
        for filename in os.listdir(folder)
@@ -123,7 +151,12 @@ def edit_files_for_hf_compatibility(
    while len(files_to_process) > 0:
        to_process = files_to_process.pop()
        if os.path.isfile(to_process) and to_process.endswith('.py'):
            to_add = process_file(to_process, folder, flatten_imports_prefix)
            to_add = process_file(
                to_process,
                folder,
                flatten_imports_prefix,
                remove_imports_prefix,
            )
            for file in to_add:
                if file not in files_processed_and_queued:
                    files_to_process.append(file)
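A small sketch (not from the commit) of how the two prefix lists interact in process_file: an import that matches remove_imports_prefix is dropped even if it would also match flatten_imports_prefix, because the removal check runs before the flattening check.

import ast

from llmfoundry.utils.huggingface_hub_utils import _flatten_import, _remove_import

node = ast.parse('from llmfoundry.metrics import TokenAccuracy').body[0]
# Matches the flatten prefix, but the remove prefix takes precedence in process_file.
assert _flatten_import(node, ('llmfoundry',))
assert _remove_import(node, ('composer', 'omegaconf', 'llmfoundry.metrics'))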
32 changes: 32 additions & 0 deletions tests/metrics/test_token_acc.py
@@ -0,0 +1,32 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from llmfoundry.metrics import TokenAccuracy


@pytest.mark.parametrize('ignore_index', [-100, -200])
@pytest.mark.parametrize('vocab_size', [100])
def test_token_accuracy(ignore_index: int, vocab_size: int):
    batch_size = int(1e6)
    torchmetrics_token_acc = TokenAccuracy(ignore_index=ignore_index)
    generated_preds = torch.rand((batch_size, vocab_size))
    true_labels = torch.randint(low=0, high=vocab_size - 1, size=(batch_size,))

    # Randomly insert ignore_index into the labels
    labels_mask = torch.rand((batch_size,))
    labels_mask[labels_mask > 0.8] = 1
    labels_mask[labels_mask <= 0.8] = 0
    labels_mask = labels_mask.bool()
    true_labels[labels_mask] = ignore_index

    true_labels = true_labels.float()
    generated_preds = generated_preds.float()

    torchmetrics_token_acc.update(generated_preds, true_labels)
    final_acc = torchmetrics_token_acc.compute()

    expected_random_acc_tensor = torch.tensor(1.0 / vocab_size)
    torch.testing.assert_close(final_acc, expected_random_acc_tensor)
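As a quick sanity check on the test's expected value (illustrative, not part of the commit): with uniform random logits the argmax is independent of the label, so the accuracy of random predictions concentrates around 1 / vocab_size, which is the value the assertion above compares against.

import torch

torch.manual_seed(0)
vocab_size, batch_size = 100, int(1e6)
random_preds = torch.rand((batch_size, vocab_size)).argmax(dim=-1)
random_labels = torch.randint(low=0, high=vocab_size - 1, size=(batch_size,))
print((random_preds == random_labels).float().mean())  # ~0.01, i.e. 1 / vocab_size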
26 changes: 25 additions & 1 deletion tests/utils/test_huggingface_hub_utils.py
@@ -3,7 +3,10 @@

import ast

from llmfoundry.utils.huggingface_hub_utils import _flatten_import
import pytest

from llmfoundry.utils.huggingface_hub_utils import (_flatten_import,
                                                    _remove_import)


def test_flatten_import_true():
@@ -14,3 +17,24 @@ def test_flatten_import_true():
def test_flatten_import_false():
    node = ast.ImportFrom('y', ['x', 'y', 'z'])
    assert not _flatten_import(node, ('x', 'z'))


@pytest.mark.parametrize('prefix_to_remove,expected_imports_remaining',
                         [('llmfoundry', 1), ('llmfoundry.utils', 2)])
def test_remove_imports(prefix_to_remove: str, expected_imports_remaining: int):
    source_code = """
from llmfoundry import a
from llmfoundry.utils import b
from other_package import c
"""

    tree = ast.parse(source_code)
    assert len(tree.body) == 3

    imports_kept = 0
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and not _remove_import(
                node, [prefix_to_remove]):
            imports_kept += 1

    assert imports_kept == expected_imports_remaining
