Skip to content

Commit

Permalink
Metrics registry (#1052)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Mar 24, 2024
1 parent 94a05bd commit 813d596
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 90 deletions.
63 changes: 63 additions & 0 deletions llmfoundry/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,71 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy, MaskedAccuracy)
from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity

from llmfoundry.metrics.token_acc import TokenAccuracy
from llmfoundry.registry import metrics

# Built-in metric classes paired with the string name each one is registered
# under. Registration happens in the order of this table.
_builtin_metrics = (
    ('token_accuracy', TokenAccuracy),
    ('lm_accuracy', InContextLearningLMAccuracy),
    ('lm_expected_calibration_error',
     InContextLearningLMExpectedCalibrationError),
    ('mc_expected_calibration_error',
     InContextLearningMCExpectedCalibrationError),
    ('mc_accuracy', InContextLearningMultipleChoiceAccuracy),
    ('qa_accuracy', InContextLearningQAAccuracy),
    ('code_eval_accuracy', InContextLearningCodeEvalAccuracy),
    ('language_cross_entropy', LanguageCrossEntropy),
    ('language_perplexity', LanguagePerplexity),
    ('masked_accuracy', MaskedAccuracy),
)
for _metric_name, _metric_cls in _builtin_metrics:
    metrics.register(_metric_name, func=_metric_cls)

# Remove the loop temporaries so the module exposes no extra names.
del _builtin_metrics, _metric_name, _metric_cls

# Registry names of the metrics used by default when training causal LMs.
DEFAULT_CAUSAL_LM_TRAIN_METRICS = [
    'language_cross_entropy',
    'language_perplexity',
    'token_accuracy',
]

# Default causal LM eval metrics: the training metrics first, then the
# in-context-learning metrics. Unpacking builds a fresh list, so the train and
# eval defaults never share the same list object.
DEFAULT_CAUSAL_LM_EVAL_METRICS = [
    *DEFAULT_CAUSAL_LM_TRAIN_METRICS,
    'lm_accuracy',
    'lm_expected_calibration_error',
    'mc_expected_calibration_error',
    'mc_accuracy',
    'qa_accuracy',
    'code_eval_accuracy',
]

# Registry names of the metrics used by default for prefix LMs.
DEFAULT_PREFIX_LM_METRICS = ['language_cross_entropy', 'masked_accuracy']

# Registry names of the metrics used by default for encoder-decoder models.
# Kept as its own list (not an alias of the prefix LM defaults).
DEFAULT_ENC_DEC_METRICS = ['language_cross_entropy', 'masked_accuracy']

# Public names of this package: the metric classes imported above, followed by
# the DEFAULT_* lists of metric names defined in this module.
__all__ = [
'TokenAccuracy',
'InContextLearningLMAccuracy',
'InContextLearningLMExpectedCalibrationError',
'InContextLearningMCExpectedCalibrationError',
'InContextLearningMultipleChoiceAccuracy',
'InContextLearningQAAccuracy',
'InContextLearningCodeEvalAccuracy',
'LanguageCrossEntropy',
'LanguagePerplexity',
'MaskedAccuracy',
'DEFAULT_CAUSAL_LM_TRAIN_METRICS',
'DEFAULT_CAUSAL_LM_EVAL_METRICS',
'DEFAULT_PREFIX_LM_METRICS',
'DEFAULT_ENC_DEC_METRICS',
]
37 changes: 12 additions & 25 deletions llmfoundry/models/hf/hf_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, Mapping

# required for loading a python model into composer
from composer.metrics.nlp import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.models.huggingface import peft_installed
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)

from llmfoundry.metrics import TokenAccuracy
from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.attention import is_flash_v2_installed
Expand Down Expand Up @@ -72,6 +65,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):

def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):
from llmfoundry.utils.builders import build_metric

pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path
pretrained_lora_id_or_path = om_model_config.get(
'pretrained_lora_id_or_path', None)
Expand Down Expand Up @@ -123,25 +118,17 @@ def __init__(self, om_model_config: DictConfig,
'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.'
)

# Set up training and eval metrics
use_train_metrics = om_model_config.get('use_train_metrics', True)
train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + om_model_config.get(
'additional_train_metrics', [])
train_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
TokenAccuracy()
]
build_metric(metric, {}) for metric in train_metric_names
] if use_train_metrics else []
eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + om_model_config.get(
'additional_eval_metrics', [])
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
TokenAccuracy(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
build_metric(metric, {}) for metric in eval_metric_names
]
if not om_model_config.get('use_train_metrics', True):
train_metrics = []

# Construct the Hugging Face config to use
config = AutoConfig.from_pretrained(
Expand Down
11 changes: 5 additions & 6 deletions llmfoundry/models/hf/hf_prefix_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from typing import Mapping, MutableMapping

from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM,
PreTrainedTokenizerBase)

from llmfoundry.metrics import DEFAULT_PREFIX_LM_METRICS
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
Expand All @@ -22,9 +22,6 @@

__all__ = ['ComposerHFPrefixLM']

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100


class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
"""Configures a :class:`.HuggingFaceModel` around a Prefix LM.
Expand Down Expand Up @@ -68,6 +65,8 @@ class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):

def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):
from llmfoundry.utils.builders import build_metric

config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=om_model_config.get('trust_remote_code', True),
Expand Down Expand Up @@ -130,8 +129,8 @@ def __init__(self, om_model_config: DictConfig,
model = convert_hf_causal_lm_to_prefix_lm(model)

metrics = [
LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
build_metric(metric, {}) for metric in DEFAULT_PREFIX_LM_METRICS +
om_model_config.get('additional_train_metrics', [])
]

composer_model = super().__init__(model=model,
Expand Down
11 changes: 5 additions & 6 deletions llmfoundry/models/hf/hf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from typing import Mapping

from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, PreTrainedTokenizerBase,
T5ForConditionalGeneration)

from llmfoundry.metrics import DEFAULT_ENC_DEC_METRICS
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
Expand All @@ -21,9 +21,6 @@

__all__ = ['ComposerHFT5']

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100


@experimental('ComposerHFT5')
class ComposerHFT5(HuggingFaceModelWithZLoss):
Expand Down Expand Up @@ -59,6 +56,8 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):

def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):
from llmfoundry.utils.builders import build_metric

config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=om_model_config.get('trust_remote_code', True),
Expand Down Expand Up @@ -122,8 +121,8 @@ def __init__(self, om_model_config: DictConfig,
f'init_device="{init_device}" must be either "cpu" or "meta".')

metrics = [
LanguageCrossEntropy(ignore_index=_HF_IGNORE_INDEX),
MaskedAccuracy(ignore_index=_HF_IGNORE_INDEX)
build_metric(metric, {}) for metric in DEFAULT_ENC_DEC_METRICS +
om_model_config.get('additional_train_metrics', [])
]

composer_model = super().__init__(model=model,
Expand Down
19 changes: 5 additions & 14 deletions llmfoundry/models/inference_api_wrapper/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,22 @@
import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.metrics.nlp import (InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy,
LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer

from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS
from llmfoundry.utils.builders import build_metric


class InferenceAPIEvalWrapper(ComposerModel):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
self.tokenizer = tokenizer
self.labels = None
# set up training and eval metrics
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError()
build_metric(metric, {})
for metric in DEFAULT_CAUSAL_LM_EVAL_METRICS
]
self.eval_metrics = {
metric.__class__.__name__: metric for metric in eval_metrics
Expand Down
32 changes: 12 additions & 20 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics import (InContextLearningCodeEvalAccuracy,
InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
InContextLearningMCExpectedCalibrationError,
InContextLearningMultipleChoiceAccuracy,
InContextLearningQAAccuracy)
from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
from composer.models import HuggingFaceModel
from composer.utils import dist

from llmfoundry.metrics import TokenAccuracy
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

Expand Down Expand Up @@ -1034,27 +1026,27 @@ def __init__(
om_model_config: DictConfig,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
):
from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
from llmfoundry.utils.builders import build_metric

resolved_om_model_config = om.to_container(om_model_config,
resolve=True)
assert isinstance(resolved_om_model_config, dict)

hf_config = MPTConfig.from_dict(resolved_om_model_config)
model = MPTForCausalLM(hf_config)

use_train_metrics = om_model_config.get('use_train_metrics', True)
train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + resolved_om_model_config.get(
'additional_train_metrics', [])
train_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
TokenAccuracy()
build_metric(metric, {}) for metric in train_metric_names
] if use_train_metrics else []
eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + resolved_om_model_config.get(
'additional_eval_metrics', [])
eval_metrics = [
LanguageCrossEntropy(),
LanguagePerplexity(),
TokenAccuracy(),
InContextLearningLMAccuracy(),
InContextLearningMultipleChoiceAccuracy(),
InContextLearningQAAccuracy(),
InContextLearningCodeEvalAccuracy(),
InContextLearningLMExpectedCalibrationError(),
InContextLearningMCExpectedCalibrationError(),
build_metric(metric, {}) for metric in eval_metric_names
]

super().__init__(
Expand Down
9 changes: 9 additions & 0 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from composer.loggers import LoggerDestination
from composer.optim import ComposerScheduler
from torch.optim import Optimizer
from torchmetrics import Metric

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.utils.registry_utils import create_registry
Expand Down Expand Up @@ -64,11 +65,19 @@
entry_points=True,
description=_schedulers_description)

# Registry for metric classes; entries are typed as Type[Metric], where Metric
# is imported from torchmetrics.
# NOTE(review): entry_points=True presumably lets other installed packages add
# metrics through Python entry points — confirm against create_registry.
_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
metrics = create_registry('llmfoundry',
'metrics',
generic_type=Type[Metric],
entry_points=True,
description=_metrics_description)

__all__ = [
'loggers',
'callbacks',
'callbacks_with_config',
'optimizers',
'algorithms',
'schedulers',
'metrics',
]
10 changes: 2 additions & 8 deletions llmfoundry/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.utils.builders import (build_algorithm, build_callback,
build_evaluators,
build_icl_data_and_gauntlet,
build_icl_evaluators, build_logger,
build_optimizer, build_scheduler,
build_tokenizer)
build_logger, build_optimizer,
build_scheduler, build_tokenizer)
from llmfoundry.utils.checkpoint_conversion_helpers import (
convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict,
load_tokenizer)
Expand Down Expand Up @@ -34,9 +31,6 @@
__all__ = [
'build_algorithm',
'build_callback',
'build_evaluators',
'build_icl_data_and_gauntlet',
'build_icl_evaluators',
'build_logger',
'build_optimizer',
'build_scheduler',
Expand Down
Loading

0 comments on commit 813d596

Please sign in to comment.