Commit
Merge branch 'migrate_subclasses_to_foundry' of github.com:mosaicml/llm-foundry into migrate_subclasses_to_foundry
maxisawesome committed Apr 10, 2024
2 parents a5082b0 + d7272b1 commit 3c8ac56
Showing 39 changed files with 466 additions and 310 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -306,7 +306,7 @@ dependencies = [
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
[project.entry-points."llmfoundry_loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```
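For orientation only: a minimal sketch of the logger class such an entry point might point at. The module path `foundry_registry.loggers` and the class name `MyLogger` come from the snippet above; basing it on Composer's `LoggerDestination` is an assumption of this sketch, not something this diff prescribes.

```python
# foundry_registry/loggers.py -- illustrative sketch, not part of this commit.
from typing import Any, Dict

from composer.loggers import LoggerDestination


class MyLogger(LoggerDestination):
    """Toy logger destination; prints hyperparameters instead of shipping them anywhere."""

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]) -> None:
        for name, value in hyperparameters.items():
            print(f'{name}: {value}')
```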

5 changes: 1 addition & 4 deletions llmfoundry/eval/datasets/__init__.py
@@ -1,10 +1,7 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Natively supported datasets."""
"""Natively supported in-context learning evaluation datasets."""

from llmfoundry.eval.datasets.in_context_learning_evaluation import (
InContextLearningCodeEvalDataset, InContextLearningDataset,
15 changes: 7 additions & 8 deletions llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -1,10 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
# This code is based on the implementation in https://github.com/EleutherAI/lm-evaluation-harness/blob/8c048e266a22a1c85ccbdb0c209ac712e4f39989/lm_eval/base.py#L221-L330

from __future__ import annotations

import copy
@@ -97,14 +93,17 @@ class InContextLearningDataset(Dataset):
strip_dataset (bool): Boolean for whether to strip whitespace from data. Trailing whitespace can cause degenerative outputs,
so unless whitespace should be preserved (for example in code), this should be set to True.
padding_side (str): Side of the content and answer on which to apply padding. Can be either 'right' or 'left'.
tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses.
padding_size (int): The final size of the tensor after padding. Defaults to max_sequence_length.
base_batch (Dict): The base dictionary upon which a batch is created. See above for more details.
batch_mapping (Dict): A mapping of batch keys to dataset columns, used to create batches. See above for more details.
hf_loading_vars (Dict): A dictionary containing keyword arguments to be passed into `load_dataset` if dataset is being pulled from HF.
hf_parsing_map (Dict): A dictionary containing a mapping from HF columns to ICL dataset keys. The dictionary should be formatted {icl_key:[hf_key1, hf_key2]}.
    Column contents will be concatenated with ' ' separating them. If not included, will load the columns already present in the HF dataset.
tokenize_labels (bool): Whether or not the labels should be tokenized. Generally determined by which metric a dataset uses.
generation_kwargs (Dict): A dictionary containing keyword arguments to be passed along to the model's generate function.
static_keys (List): A list of the key values which will be broadcast across a batch (e.g. it is the same for each batch element).
list_keys (List): A list of the batch keys whose values are lists which will be split using list methods during calls to split_batch.
tensor_keys (List): A list of the batch keys whose values are tensors which will be split using tensor methods during calls to split_batch.
"""

def __init__(
@@ -125,15 +124,15 @@ def __init__(
strip_dataset: bool = True,
padding_side: str = 'right',
tokenize_labels: bool = True,
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
padding_size: Optional[int] = None,
base_batch: Optional[Dict] = None,
batch_mapping: Optional[Dict] = None,
hf_loading_vars: Optional[Dict] = None,
hf_parsing_map: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = None,
static_keys: Optional[List] = None,
list_keys: Optional[List] = None,
tensor_keys: Optional[List] = None,
):
self.tokenizer = tokenizer
self.prefix_space = tokenizer_needs_prefix_space(self.tokenizer)
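As a rough illustration of how the batch-construction arguments documented in this hunk fit together, here is a hedged sketch; the concrete key and column names (`input_ids`, `context`, and so on) are invented for the example and are not taken from this diff.

```python
# Illustrative only: shows how base_batch, batch_mapping, and the key-group lists
# relate, using made-up key and column names.
base_batch = {
    'input_ids': [],          # filled per example from the tokenized context
    'labels': [],             # filled per example from the answer column
    'mode': 'icl_task',       # identical for every element -> listed in static_keys
    'generation_kwargs': {},  # forwarded to the model's generate() call
}
batch_mapping = {
    'input_ids': 'context',   # batch key <- dataset column
    'labels': 'answer',
}
static_keys = ['mode', 'generation_kwargs']  # broadcast across the batch
list_keys = ['labels']                       # split with list slicing in split_batch
tensor_keys = ['input_ids']                  # split with tensor indexing in split_batch
```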
3 changes: 0 additions & 3 deletions llmfoundry/eval/datasets/utils.py
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility and helper functions for datasets."""
from __future__ import annotations

3 changes: 0 additions & 3 deletions llmfoundry/eval/metrics/__init__.py
@@ -1,9 +1,6 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""

from llmfoundry.eval.metrics.nlp import (
20 changes: 20 additions & 0 deletions llmfoundry/layers_registry.py
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

__all__ = [
'norms',
]
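A hedged sketch of how an external norm could be added through this new registry; the decorator form mirrors the `register_class` calls in `llmfoundry/models/layers/norm.py` later in this diff, while the class and the registry name `identity_norm` are hypothetical.

```python
# Hypothetical example: registering a custom norm so it can be constructed by name.
import torch

from llmfoundry.layers_registry import norms


@norms.register_class('identity_norm')  # name invented for illustration
class IdentityNorm(torch.nn.Module):
    """No-op norm; exists only to show the registration pattern."""

    def __init__(self, normalized_shape, device=None):
        super().__init__()
        del normalized_shape, device  # unused in this toy example

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
```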
17 changes: 14 additions & 3 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -8,12 +8,11 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, Mapping

# required for loading a python model into composer
from composer.models.huggingface import peft_installed
from composer.utils import dist
from omegaconf import DictConfig
from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
PreTrainedModel, PreTrainedTokenizerBase)

from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
DEFAULT_CAUSAL_LM_TRAIN_METRICS)
@@ -162,6 +161,18 @@ def _autoset_attn_implementation_monkeypatch(
elif attr is None and isinstance(v, Mapping):
setattr(config, k, {})
getattr(config, k).update(v)
elif isinstance(attr, PretrainedConfig):
if not isinstance(v, Mapping):
raise ValueError(
f'Expected a dictionary for config override {k}, but got {v}.'
)

for _k, _v in v.items():
if not hasattr(attr, _k):
raise ValueError(
f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).'
)
setattr(attr, _k, _v)
else:
setattr(config, k, v)

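A hedged sketch of the override shape the new `PretrainedConfig` branch above accepts: a dict whose keys are existing attributes of the nested sub-config. `text_config` stands in for whatever sub-config attribute the loaded model happens to expose; it is not mandated by this diff.

```python
# Hypothetical config_overrides handled by the new branch: the outer key names an
# attribute on the loaded config that is itself a PretrainedConfig, and every inner
# key must already exist on that sub-config.
config_overrides = {
    'text_config': {
        'num_hidden_layers': 2,
        'hidden_size': 128,
    },
}
```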
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -9,7 +9,7 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
@@ -23,7 +23,6 @@
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -15,7 +15,7 @@
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -419,12 +419,19 @@ def __init__(
self.Wqkv._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
norm_size = self.head_dim if qk_gn else d_model
self.q_ln = norm_class(norm_size, device=device)
self.q_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = norm_class(norm_size, device=device)
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -10,7 +10,7 @@

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -72,7 +72,6 @@ def __init__(
del kwargs # unused, just to capture any extra args from the config
super().__init__()

norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

@@ -88,7 +87,11 @@ def __init__(
if k not in args_to_exclude_in_attn_class
}

self.norm_1 = norm_class(d_model, device=device)
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
@@ -100,7 +103,11 @@ def __init__(
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
self.norm_2 = norm_class(d_model, device=device)
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.ffn = build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
25 changes: 25 additions & 0 deletions llmfoundry/models/layers/layer_builders.py
@@ -0,0 +1,25 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import construct_from_registry


def build_norm(
name: str,
normalized_shape: Union[int, List[int], torch.Size],
device: Optional[str] = None,
):
kwargs = {
'normalized_shape': normalized_shape,
'device': device,
}

return construct_from_registry(name=name,
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
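A usage sketch for the new builder; `'low_precision_layernorm'` is one of the names registered in `norm.py` below, while the shape and device values are arbitrary examples.

```python
# Illustrative usage of build_norm with a name registered in norm.py.
import torch

from llmfoundry.models.layers.layer_builders import build_norm

norm = build_norm(
    name='low_precision_layernorm',
    normalized_shape=768,  # e.g. d_model
    device='cpu',
)
out = norm(torch.randn(2, 16, 768))  # applies the norm over the last dimension
```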
19 changes: 9 additions & 10 deletions llmfoundry/models/layers/norm.py
@@ -1,10 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Type, Union
from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms

norms.register(name='layernorm', func=torch.nn.LayerNorm)


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_autocast_enabled():
@@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
return tensor


@norms.register_class('low_precision_layernorm')
class LPLayerNorm(torch.nn.LayerNorm):

def __init__(
@@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor,
return output


@norms.register_class('rmsnorm')
class RMSNorm(torch.nn.Module):

def __init__(
@@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


@norms.register_class('low_precision_rmsnorm')
class LPRMSNorm(RMSNorm):

def __init__(
@@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.eps).to(dtype=x.dtype)


@norms.register_class('triton_rmsnorm')
class TritonRMSNorm(torch.nn.Module):

def __init__(
@@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor):
prenorm=False,
residual_in_fp32=False,
)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
'low_precision_rmsnorm': LPRMSNorm,
'triton_rmsnorm': TritonRMSNorm,
}
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -19,6 +19,9 @@
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
14 changes: 9 additions & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -20,7 +20,6 @@
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -42,11 +41,13 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig):
else:
config.init_device = 'meta'

if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
if config.norm_type.lower() not in norms.get_all():
norm_options = ' | '.join(norms.get_all())
raise NotImplementedError(
f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
)
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]

# CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
# both report this helping with stabilizing training
@@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig):
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)
self.norm_f = build_norm(
name=config.norm_type.lower(),
normalized_shape=config.d_model,
device=config.init_device,
)

self.rope = config.attn_config['rope']
self.rope_impl = None