Skip to content

Commit

Permalink
Merge branch 'main' into mlf-sig
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Feb 6, 2024
2 parents 8244ab6 + 6591f48 commit e06ac09
Show file tree
Hide file tree
Showing 23 changed files with 11,928 additions and 706 deletions.
4 changes: 0 additions & 4 deletions llmfoundry/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.generate_callback import Generate
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
Expand All @@ -21,13 +19,11 @@

__all__ = [
'FDiffMetrics',
'Generate',
'MonolithicCheckpointSaver',
'GlobalLRScaling',
'LayerFreezing',
'ScheduledGarbageCollector',
'EvalGauntlet',
'ModelGauntlet',
'HuggingFaceCheckpointer',
'AsyncEval',
]
32 changes: 0 additions & 32 deletions llmfoundry/callbacks/generate_callback.py

This file was deleted.

21 changes: 0 additions & 21 deletions llmfoundry/callbacks/model_gauntlet_callback.py

This file was deleted.

106 changes: 0 additions & 106 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

Expand All @@ -11,8 +10,6 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -428,106 +425,3 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
padding, waste = profile(raw_batch_size)
yield (packing_ratio, padding, waste)


if __name__ == '__main__':

import warnings

warnings.warn(
VersionedDeprecationWarning(
'Please use scripts/misc/profile_packing.py to profile packing.',
remove_version='0.5.0',
))

import os
from argparse import ArgumentParser, Namespace

from omegaconf import OmegaConf as om

from llmfoundry.utils import build_tokenizer

def parse_args() -> Namespace:
"""Parse commandline arguments."""
parser = ArgumentParser(
description=
'Profile packing_ratio choices for a particular workload.')
parser.add_argument(
'--yaml-path',
type=str,
required=True,
help='Path to the YAML that defines the workload to profile.')
parser.add_argument('--num-devices',
type=int,
default=None,
help='How many devices your run will use.')
parser.add_argument('--min',
type=float,
required=True,
help='Smallest packing_ratio to test. Must be >=1.')
parser.add_argument(
'--max',
type=float,
required=True,
help='Largest packing_ratio to test. Must be larger than `min`.')
parser.add_argument(
'--num-packing-ratios',
type=int,
default=20,
help=
'Number of packing_ratio values (spaced between `min` and `max) to try.'
)

args = parser.parse_args()

if not os.path.isfile(args.yaml_path):
raise FileNotFoundError(
'`yaml_path` does not correspond to any existing file.')
if args.num_devices < 1:
raise ValueError('`num_devices` must be a positive integer.')
if args.min < 1.0:
raise ValueError('`min` must be >=1.0.')
if args.max < args.min:
raise ValueError('`max` cannot be less than `min`.')
if args.num_packing_ratios < 1:
raise ValueError('`num_packing_ratios` must be a positive integer.')
return args

args = parse_args()

with open(args.yaml_path) as f:
cfg = om.load(f)
if 'parameters' in cfg:
cfg = om.to_container(cfg.parameters)
cfg = om.create(cfg)
device_batch_size = cfg.global_train_batch_size // args.num_devices

# Fetch a bunch of raw examples once, which we'll re-use
if 'train_loader' not in cfg:
raise ValueError('config must define train_loader')
dataloader_cfg = cfg.train_loader

# build tokenizer
if 'tokenizer' not in cfg:
raise ValueError('config must define tokenizer')

resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
if not isinstance(resolved_tokenizer_cfg, Dict):
raise ValueError(
'tokenizer config needs to be resolved by omegaconf into a Dict.')
tokenizer_cfg = resolved_tokenizer_cfg

tokenizer_name = tokenizer_cfg['name']
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max,
args.num_packing_ratios, device_batch_size)

header = '\n\n\n packing_ratio | % PADDING | % WASTE'
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'

print(header)
print('-' * len(header))
for packing_ratio, padding, waste in results:
print(fstr.format(packing_ratio, padding, waste))
13 changes: 0 additions & 13 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ def __init__(self,
batching_method: str = 'random',
**kwargs: Any):

group_method = kwargs.pop('group_method', None)
if group_method is not None:
raise NotImplementedError(
'group_method is deprecated and has been removed.\nTo ' +
'concatenate, use the --concat_tokens ' +
'argument when creating your MDS dataset with concat_c4.py')

if len(kwargs) > 0:
raise ValueError(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}'
Expand Down Expand Up @@ -245,12 +238,6 @@ def build_text_dataloader(
device_batch_size: int,
) -> DataSpec:
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
if cfg.dataset.get('group_method', None) is not None:
raise NotImplementedError(
'group_method is deprecated and has been removed.\nTo ' +
'concatenate, use the --concat_tokens ' +
'argument when creating your MDS dataset with convert_dataset_hf.py'
)

# get kwargs
streams_dict = cfg.dataset.pop('streams', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
# TODO: this will be deprecated
self.generate_completion = lambda prompt, num_tokens: self.client.completions.create(
model=self.model_name,
prompt=prompt,
Expand Down
53 changes: 3 additions & 50 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.utils.warnings import VersionedDeprecationWarning


def is_flash_v2_installed(v2_version: str = '2.0.0'):
Expand Down Expand Up @@ -91,7 +90,7 @@ def scaled_multihead_dot_product_attention(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
kv_n_heads: int,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
Expand All @@ -100,23 +99,8 @@ def scaled_multihead_dot_product_attention(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
if multiquery:
warnings.warn(
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
Expand Down Expand Up @@ -221,7 +205,7 @@ def flash_attn_fn(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
kv_n_heads: int,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -250,21 +234,6 @@ def flash_attn_fn(

check_valid_inputs(query, key, value)

if multiquery:
warnings.warn(
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

if past_key_value is not None:
if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
Expand Down Expand Up @@ -384,7 +353,7 @@ def triton_flash_attn_fn(
key: torch.Tensor,
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
kv_n_heads: int,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
Expand All @@ -393,7 +362,6 @@ def triton_flash_attn_fn(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
Expand Down Expand Up @@ -425,21 +393,6 @@ def triton_flash_attn_fn(

check_valid_inputs(query, key, value)

if multiquery:
warnings.warn(
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
remove_version='0.5.0',
))
kv_n_heads = n_heads

if past_key_value is not None:
if len(past_key_value) != 0:
key = torch.cat([past_key_value[0], key], dim=1)
Expand Down
4 changes: 0 additions & 4 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def llama_attention_patch_torch(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
# Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
Expand Down Expand Up @@ -188,8 +186,6 @@ def llama_attention_patch_triton(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
# Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
Expand Down
8 changes: 0 additions & 8 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def __init__(
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffn: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
"""The MPT configuration class.
Expand Down Expand Up @@ -116,7 +115,6 @@ def __init__(
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): Deprecated.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
Expand Down Expand Up @@ -159,12 +157,6 @@ def __init__(
self.init_config = init_config
self.fc_type = fc_type
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
if verbose is not None:
warnings.warn(
VersionedDeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.',
remove_version='0.5.0',
))

if 'name' in kwargs:
del kwargs['name']
Expand Down
Loading

0 comments on commit e06ac09

Please sign in to comment.