From d7c78229e91129d4c35006209fabd5fb2f2252e9 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Sun, 22 Sep 2024 14:03:42 -0400 Subject: [PATCH 01/15] Fix reuse kv cache for torch attention (#1539) --- llmfoundry/models/layers/attention.py | 3 +++ tests/models/layers/test_flash_torch.py | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index a1af2235cf..625327767e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -656,6 +656,9 @@ def get_qkv( 'prev_layer_key_value is None, cannot reuse_prev_layer_kv.', ) key, value = prev_layer_key_value + if self.attn_impl == 'torch': + key = rearrange(key, 'b h d s -> b s (h d)') + value = rearrange(value, 'b h s d -> b s (h d)') query = self.Wq(x) if self.clip_qkv: diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 01a6a7576d..0a4b32a73a 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -188,7 +188,7 @@ def gen_bias(attn_impl: str): alibi=alibi, alibi_bias_max=8, ) - if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None: + if attn_impl == 'torch' and attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) # pyright attn_bias = apply_sequence_id( attn_bias, @@ -561,8 +561,10 @@ def test_grouped_query_invalid_heads(): }, }], ) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) def test_reuse_prev_layer_kv_cache( pos_emb_config: dict, + attn_impl: str, device: str = 'cuda', ): """Checks reusing previous layer's kv cache.""" @@ -570,7 +572,7 @@ def test_reuse_prev_layer_kv_cache( rope = pos_emb_config['rope'] cfg = { - 'attn_impl': 'flash', + 'attn_impl': attn_impl, 'd_model': 64, 'n_heads': 4, 'attn_pdrop': 0, @@ -630,6 +632,13 @@ def gen_bias(attn_impl: str): alibi=alibi, alibi_bias_max=8, ) + if attn_impl == 'torch': + assert isinstance(attn_bias, torch.Tensor) # pyright + attn_bias = apply_sequence_id( + attn_bias, + sequence_id, # type: ignore + s, + ) return attn_bias @@ -637,7 +646,7 @@ def gen_bias(attn_impl: str): sequence_id=sequence_id, S=s, attn_uses_sequence_id=True, - attn_impl='flash', + attn_impl=attn_impl, attention_mask=attention_mask, ) @@ -656,7 +665,7 @@ def gen_bias(attn_impl: str): x1.requires_grad = True with torch.autocast(x0.device.type): - attn_bias_0 = gen_bias('flash') + attn_bias_0 = gen_bias(attn_impl) alibi_slopes_0 = None if alibi: alibi_slopes_0 = gen_slopes( @@ -703,7 +712,7 @@ def gen_bias(attn_impl: str): flash_attn_padding_info=flash_attn_padding_info, alibi_slopes=alibi_slopes_0, ) - attn_bias_1 = gen_bias('flash') + attn_bias_1 = gen_bias(attn_impl) alibi_slopes_1 = None if alibi: alibi_slopes_1 = gen_slopes( From 14cff668750dc08eb4511ddee0d55b127e711dea Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Sun, 22 Sep 2024 19:49:21 -0400 Subject: [PATCH 02/15] Error on text dataset file not found (#1534) --- .../data_prep/convert_text_to_mds.py | 15 ++++++++++----- llmfoundry/utils/exceptions.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py index 9a1f8a912d..3ea5aeb5d4 100644 --- a/llmfoundry/command_utils/data_prep/convert_text_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_text_to_mds.py @@ -32,6 +32,7 @@ CannotUnicodeDecodeFile, DatasetTooSmallError, InputFolderMissingDataError, + InputFolderNotFound, OutputFolderNotEmptyError, ) @@ -125,11 +126,15 @@ def get_object_names(input_folder: str) -> list[str]: object_store = maybe_create_object_store_from_uri(input_folder) if object_store is not None: _, _, folder_prefix = parse_uri(input_folder) - names = [ - name for name in object_store.list_objects(folder_prefix) - if name.endswith('.txt') - ] - log.info(f'Found {len(names)} text files in remote storage') + try: + names = [ + name for name in object_store.list_objects(folder_prefix) + if name.endswith('.txt') + ] + log.info(f'Found {len(names)} text files in remote storage') + except FileNotFoundError: + raise InputFolderNotFound(folder_prefix) + else: # input_folder is a local folder names = [ diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 11895564f2..900355dff5 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -348,6 +348,17 @@ def __init__(self, input_folder: str) -> None: super().__init__(message, input_folder=input_folder) +class InputFolderNotFound(UserError): + """Error thrown when the a folder is not found.""" + + def __init__(self, folder_that_was_not_found: str) -> None: + message = f'{folder_that_was_not_found} not found.' + super().__init__( + message, + folder_that_was_not_found=folder_that_was_not_found, + ) + + class CannotUnicodeDecodeFile(UserError): """Error thrown when the input folder is missing data.""" From a2c0507795a887b6fb71d3ef975b714523fe2abb Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Sun, 22 Sep 2024 18:23:51 -0700 Subject: [PATCH 03/15] Make ICL tasks not required for eval (#1540) --- llmfoundry/command_utils/eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index e644ad1f0f..70c4319ea8 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -262,7 +262,7 @@ def evaluate(cfg: DictConfig) -> tuple[list[Trainer], pd.DataFrame]: EvalConfig, EVAL_CONFIG_KEYS, transforms=[allow_toplevel_keys], - icl_tasks_required=True, + icl_tasks_required=False, ) model_configs = eval_config.models @@ -273,7 +273,7 @@ def evaluate(cfg: DictConfig) -> tuple[list[Trainer], pd.DataFrame]: # Mandatory Evaluation Parameters icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str if icl_tasks is None: - raise ValueError('icl_tasks must be specified in the config') + icl_tasks = [] # Optional Evaluation Parameters with default values eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders From 85403c086710bc0f62d03fc03c0fcbb2e5ffda1d Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 23 Sep 2024 10:37:26 -0400 Subject: [PATCH 04/15] Bumping flash attention version to 2.6.3 and adding option for softcap in attention and lm_head logits. (#1374) --- llmfoundry/models/layers/attention.py | 24 +++++- llmfoundry/models/mpt/configuration_mpt.py | 14 +++ llmfoundry/models/mpt/modeling_mpt.py | 6 ++ llmfoundry/models/utils/config_defaults.py | 1 + setup.py | 2 +- tests/models/layers/test_flash_attn.py | 99 +++++++++++++++++++++- 6 files changed, 140 insertions(+), 6 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 625327767e..612d6b9642 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -112,6 +112,7 @@ def scaled_multihead_dot_product_attention( dropout_p: float = 0.0, training: bool = False, needs_weights: bool = False, + attn_logit_softcapping: Optional[float] = None, sliding_window_size: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: @@ -149,6 +150,11 @@ def scaled_multihead_dot_product_attention( attn_weight = q.matmul(k) * softmax_scale + if attn_logit_softcapping is not None: + attn_weight = attn_logit_softcapping * torch.tanh( + attn_weight / attn_logit_softcapping, + ) + if attn_bias is not None: # clamp to 0 necessary for torch 2.0 compile() _s_q = max(0, attn_bias.size(2) - s_q) @@ -264,6 +270,7 @@ def flash_attn_fn( sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, + attn_logit_softcapping: Optional[float] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: if key_padding_mask is not None: @@ -381,13 +388,17 @@ def flash_attn_fn( return_attn_probs=needs_weights, ) elif is_flash_v2_installed(): - alibi_kwargs = {} + extra_attn_kwargs = {} if check_alibi_support('flash'): - alibi_kwargs = {'alibi_slopes': alibi_slopes} + extra_attn_kwargs['alibi_slopes'] = alibi_slopes elif alibi_slopes is not None: raise ValueError( 'alibi_slopes is only supported for flash-attn>=2.4.2', ) + if is_flash_v2_installed( + v2_version='v2.6.2', + ) and attn_logit_softcapping is not None: + extra_attn_kwargs['softcap'] = attn_logit_softcapping output_unpad = flash_attn_interface.flash_attn_varlen_func( q=query_unpad, k=key_unpad, @@ -401,7 +412,7 @@ def flash_attn_fn( causal=reset_is_causal, return_attn_probs=needs_weights, window_size=(sliding_window_size, sliding_window_size), - **alibi_kwargs, + **extra_attn_kwargs, ) else: raise RuntimeError( @@ -448,6 +459,7 @@ def __init__( bias: bool = True, sliding_window_size: int = -1, reuse_kv_layer_idx: Optional[int] = None, + attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, ): super().__init__() @@ -463,6 +475,7 @@ def __init__( self.kv_n_heads = kv_n_heads self.sliding_window_size = sliding_window_size self.reuse_kv_layer_idx = reuse_kv_layer_idx + self.attn_logit_softcapping = attn_logit_softcapping self.kv_dim = kv_dim if kv_dim is not None else self.d_model self.head_dim = d_model // n_heads @@ -625,6 +638,7 @@ def forward( dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, + attn_logit_softcapping=self.attn_logit_softcapping, sliding_window_size=self.sliding_window_size, **extra_attn_kwargs, ) @@ -853,6 +867,7 @@ def __init__( bias: bool = True, sliding_window_size: int = -1, reuse_kv_layer_idx: Optional[int] = None, + attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, ): super().__init__( @@ -873,6 +888,7 @@ def __init__( bias=bias, sliding_window_size=sliding_window_size, reuse_kv_layer_idx=reuse_kv_layer_idx, + attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim, ) @@ -902,6 +918,7 @@ def __init__( bias: bool = True, sliding_window_size: int = -1, reuse_kv_layer_idx: Optional[int] = None, + attn_logit_softcapping: Optional[float] = None, kv_dim: Optional[int] = None, ): super().__init__( @@ -922,6 +939,7 @@ def __init__( bias=bias, sliding_window_size=sliding_window_size, reuse_kv_layer_idx=reuse_kv_layer_idx, + attn_logit_softcapping=attn_logit_softcapping, kv_dim=kv_dim, ) diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 91b431e3b4..dbcabdf5f9 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -51,6 +51,7 @@ def __init__( tie_word_embeddings: bool = True, use_pad_tok_in_ffn: bool = True, block_overrides: Optional[dict[str, Any]] = None, + final_logit_softcapping: Optional[float] = None, **kwargs: Any, ): """The MPT configuration class. @@ -148,6 +149,7 @@ def __init__( reuse_kv_layer: attn_config: reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse + final_logit_softcapping (float | None): Softcapping threshold for final logit. Set to None to disable (default value None). Please see https://arxiv.org/pdf/2403.08295 for more details. kwargs (Any): Other relevant keyword arguments. """ self.d_model = d_model @@ -181,6 +183,7 @@ def __init__( if block_overrides is not None: self._validate_block_overrides(block_overrides) self.block_overrides = block_overrides + self.final_logit_softcapping = final_logit_softcapping if isinstance(fc_type, str): fc_type = {'name': fc_type} @@ -325,6 +328,17 @@ def _validate_config(self) -> None: raise NotImplementedError( 'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).', ) + if self.attn_config['attn_logit_softcapping'] is not None: + if self.attn_config['attn_logit_softcapping'] <= 0: + raise ValueError( + 'Attention attn_logit_softcapping should be positive.', + ) + if self.attn_config[ + 'attn_impl' + ] == 'flash' and not is_flash_v2_installed(v2_version='v2.6.2',): + raise NotImplementedError( + 'Attention attn_logit_softcapping is only implemented with torch attention or flash attention v2.6.2 (or higher).', + ) if self.attn_config['kv_dim'] is not None and self.attn_config[ 'fused_qkv']: raise ValueError( diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index cfe1172634..9212f5594d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1071,6 +1071,7 @@ def __init__(self, config: MPTConfig): f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.", ) self.logit_scale = logit_scale + self.final_logit_softcapping = config.final_logit_softcapping @property def backbone_model_class(self) -> type[MPTModel]: @@ -1172,6 +1173,11 @@ def forward( ) logits *= self.logit_scale + if self.final_logit_softcapping is not None: + logits = self.final_logit_softcapping * torch.tanh( + logits / self.final_logit_softcapping, + ) + loss = None if labels is not None: _labels = torch.roll(labels, shifts=-1) diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py index bd3b29a479..5550785149 100644 --- a/llmfoundry/models/utils/config_defaults.py +++ b/llmfoundry/models/utils/config_defaults.py @@ -18,6 +18,7 @@ 'softmax_scale': None, 'attn_uses_sequence_id': False, 'sliding_window_size': -1, + 'attn_logit_softcapping': None, 'alibi': False, 'alibi_bias_max': 8, 'rope': False, diff --git a/setup.py b/setup.py index 0a75c610b8..ebc66fdacf 100644 --- a/setup.py +++ b/setup.py @@ -104,7 +104,7 @@ # Flash 2 group kept for backwards compatibility extra_deps['gpu-flash2'] = [ - 'flash-attn>=2.5.8,<3', + 'flash-attn>=2.6.3,<3', ] extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2']) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 987ea7160a..666d93c9b4 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from typing import Optional import pytest import torch @@ -334,5 +335,99 @@ def gen_bias(): _assert_approx_equal(value_1.grad, value_2.grad) -def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor): - assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2) +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.6.2'), + reason= + 'attn_logit_softcapping only supported by Flash Attention after v2.6.2.', +) +@pytest.mark.parametrize( + 'attn_logit_softcapping', + [None, 0.1, 1.0, 10.0, 100.0], +) +def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]): + # Test that attn_logit_softcapping in attention works as expected. + dtype = torch.bfloat16 + device = 'cuda' + d = 128 + seqlen_1 = 8 + bsz = 2 + n_heads = 4 + + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + value_1.requires_grad = True + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + should_repeat_kv_for_gqa=True, + attn_logit_softcapping=attn_logit_softcapping, + ) + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + output_2, _, _ = scaled_multihead_dot_product_attention( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + attn_logit_softcapping=attn_logit_softcapping, + ) + output_2.sum().backward() + + _assert_approx_equal(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + _assert_approx_equal(query_1.grad, query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + _assert_approx_equal(key_1.grad, key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + _assert_approx_equal(value_1.grad, value_2.grad) + + +def _assert_approx_equal( + value1: torch.Tensor, + value2: torch.Tensor, + atol: float = 1e-2, + rtol: float = 1e-2, +): + actual_difference = torch.norm(value2 - value1) + allowed_difference = atol + rtol * torch.norm(value2) + assert actual_difference < allowed_difference, f'{actual_difference=}, {allowed_difference=}' From f377090dec102afc646fb29a4510ded6ae74ecf9 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:00:07 -0700 Subject: [PATCH 05/15] Register mosaic logger (#1542) --- llmfoundry/loggers/__init__.py | 2 ++ tests/loggers/test_mosaic_ml_logger.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 tests/loggers/test_mosaic_ml_logger.py diff --git a/llmfoundry/loggers/__init__.py b/llmfoundry/loggers/__init__.py index cd3f3fdc62..c60d9be2cd 100644 --- a/llmfoundry/loggers/__init__.py +++ b/llmfoundry/loggers/__init__.py @@ -4,6 +4,7 @@ from composer.loggers import ( InMemoryLogger, MLFlowLogger, + MosaicMLLogger, TensorboardLogger, WandBLogger, ) @@ -18,3 +19,4 @@ func=InMemoryLogger, ) # for backwards compatibility loggers.register('mlflow', func=MLFlowLogger) +loggers.register('mosaicml', func=MosaicMLLogger) diff --git a/tests/loggers/test_mosaic_ml_logger.py b/tests/loggers/test_mosaic_ml_logger.py new file mode 100644 index 0000000000..e9c003321b --- /dev/null +++ b/tests/loggers/test_mosaic_ml_logger.py @@ -0,0 +1,16 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from composer.loggers import MosaicMLLogger + +from llmfoundry.utils.builders import build_logger + + +def test_mosaic_ml_logger_constructs(): + mosaic_ml_logger = build_logger( + 'mosaicml', + kwargs={'ignore_exceptions': True}, + ) + + assert isinstance(mosaic_ml_logger, MosaicMLLogger) + assert mosaic_ml_logger.ignore_exceptions == True From d85c83b15d5b07a1b8cd00eaa7e400aaf7b22ea7 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Mon, 23 Sep 2024 23:24:16 -0700 Subject: [PATCH 06/15] Hfcheckpointer optional generation config (#1543) Co-authored-by: v-chen_data --- llmfoundry/callbacks/hf_checkpointer.py | 7 ++- .../inference/test_convert_composer_to_hf.py | 56 ++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 65bdcb3b6c..4365a5b2e5 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -588,9 +588,10 @@ def tensor_hook( del new_base_model_instance else: new_model_instance = type(original_model)(new_config) - new_model_instance.generation_config.update( - **original_model.generation_config.to_dict(), - ) + if new_model_instance.generation_config is not None: + new_model_instance.generation_config.update( + **original_model.generation_config.to_dict(), + ) # Then load the state dict in with "assign" so that the state dict # is loaded properly even though the model is initially on meta device. diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 66ec739a65..bf5f2a970b 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -8,13 +8,14 @@ import pathlib import shutil from argparse import Namespace -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, Union, cast from unittest import mock from unittest.mock import ANY, MagicMock, patch import catalogue import pytest import torch +import torch.nn as nn import transformers from composer import ComposerModel, Trainer from composer.loggers import MLFlowLogger @@ -23,7 +24,13 @@ from omegaconf import OmegaConf as om from torch.distributed._tensor.api import DTensor from torch.utils.data import DataLoader -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + GenerationConfig, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename @@ -1637,3 +1644,48 @@ def test_license_file_finder( found_path = _maybe_get_license_filename(str(tmp_path)) assert (found_path == license_file_name ) if license_file_name is not None else (found_path is None) + + +@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}]) +def test_generation_config_variants( + generation_config: Optional[Union[dict[str, Any], GenerationConfig]], +): + + class MockModel(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + # Ensure generation_config is always a GenerationConfig object + if isinstance(config.generation_config, dict): + self.generation_config = GenerationConfig( + **config.generation_config, + ) + else: + self.generation_config = config.generation_config + + config = AutoConfig.from_pretrained('gpt2') + # Convert dict to GenerationConfig if needed + if isinstance(generation_config, dict): + generation_config = GenerationConfig(**generation_config) + config.generation_config = generation_config + + mock_model = MockModel(config) + logger = MagicMock() + state = MagicMock() + state.timestamp.batch = 1 + state.is_model_ddp = False + state.model.model = mock_model + state.model.tokenizer = None + + checkpointer = HuggingFaceCheckpointer( + save_folder='test', + save_interval='1ba', + ) + + checkpointer._save_checkpoint( + state=state, + logger=logger, + upload_to_save_folder=False, + register_to_mlflow=False, + ) From 275a2a40d86a36882cc7963e2677628e05aaaf01 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:57:21 -0700 Subject: [PATCH 07/15] Bump composer version to 0.25.0 (#1546) --- setup.py | 8 ++++---- tests/a_scripts/inference/test_convert_composer_to_hf.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index ebc66fdacf..48c1326b0d 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.24.1,<0.25', + 'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.25.0,<0.26', 'mlflow>=2.14.1,<2.17', 'accelerate>=0.25,<0.34', # for HF inference `device_map` 'transformers>=4.43.2,<4.44', @@ -91,7 +91,7 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.24.1,<0.25', + 'mosaicml[databricks]>=0.25.0,<0.26', 'numpy<2', 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', @@ -99,7 +99,7 @@ ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.24.1,<0.25', + 'mosaicml[tensorboard]>=0.25.0,<0.26', ] # Flash 2 group kept for backwards compatibility @@ -110,7 +110,7 @@ extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2']) extra_deps['peft'] = [ - 'mosaicml[peft]>=0.24.1,<0.25', + 'mosaicml[peft]>=0.25.0,<0.26', ] extra_deps['openai'] = [ diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index bf5f2a970b..c25432dc48 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -1563,6 +1563,8 @@ def test_mptmoe_huggingface_conversion_callback( # Check output equivalence loaded_model = loaded_model.cuda().bfloat16() # type: ignore + for k, v in batch.items(): + batch[k] = v.cuda() loaded_model_logits = loaded_model( input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), From 151a2e297b603d84e1e4dfed389c3494990936e6 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:53:05 -0700 Subject: [PATCH 08/15] Bump streaming version to 0.9.0 (#1550) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 48c1326b0d..d1979faf63 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ 'mlflow>=2.14.1,<2.17', 'accelerate>=0.25,<0.34', # for HF inference `device_map` 'transformers>=4.43.2,<4.44', - 'mosaicml-streaming>=0.8.1,<0.9', + 'mosaicml-streaming>=0.9.0,<0.10', 'torch>=2.4.0,<2.4.1', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data From 722526d420dab9adc5a5be18425d5e08c97ee0c8 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:25:27 -0700 Subject: [PATCH 09/15] Bump version to 0.13.0.dev0 (#1549) --- llmfoundry/_version.py | 2 +- llmfoundry/command_utils/eval.py | 2 +- llmfoundry/models/hf/model_wrapper.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/_version.py b/llmfoundry/_version.py index 2f1f590b19..0cddcaf967 100644 --- a/llmfoundry/_version.py +++ b/llmfoundry/_version.py @@ -3,4 +3,4 @@ """The LLM Foundry Version.""" -__version__ = '0.12.0.dev0' +__version__ = '0.13.0.dev0' diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py index 70c4319ea8..73127e8a07 100644 --- a/llmfoundry/command_utils/eval.py +++ b/llmfoundry/command_utils/eval.py @@ -82,7 +82,7 @@ def evaluate_model( warnings.warn( VersionedDeprecationWarning( 'The argument fsdp_config is deprecated. Please use parallelism_config instead.', - remove_version='0.13.0', + remove_version='0.14.0', ), ) if fsdp_config and parallelism_config: diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index c8805e5d6d..f2b67db1ec 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -48,7 +48,7 @@ def __init__( warnings.warn( VersionedDeprecationWarning( '`HuggingFaceModelWithFSDP` is deprecated. In the future please use `BaseHuggingFaceModel`.', - remove_version='0.13.0', + remove_version='0.14.0', ), ) super().__init__( From c786defb6b6175243cd9e4a1b69918488ba7e3b9 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Wed, 25 Sep 2024 14:34:40 -0700 Subject: [PATCH 10/15] Add proper user error for accessing schema (#1548) Co-authored-by: v-chen_data --- .../data_prep/convert_delta_to_json.py | 24 ++++++++++++- .../data_prep/test_convert_delta_to_json.py | 35 +++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 666d0278c6..d676fc2165 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -233,7 +233,27 @@ def run_query( elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(query) + + try: + df = spark.sql(query) + except Exception as e: + from pyspark.errors import AnalysisException + if isinstance(e, AnalysisException): + if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore + match = re.search( + r"Schema\s+'([^']+)'", + e.message, # pyright: ignore + ) + if match: + schema_name = match.group(1) + action = f'using the schema {schema_name}' + else: + action = 'using the schema' + raise InsufficientPermissionsError(action=action,) from e + raise RuntimeError( + f'Error in querying into schema. Restart sparkSession and try again', + ) from e + if collect: return df.collect() return df @@ -461,6 +481,8 @@ def fetch( raise InsufficientPermissionsError( action=f'reading from {tablename}', ) from e + if isinstance(e, InsufficientPermissionsError): + raise e raise RuntimeError( f'Error in get rows from {tablename}. Restart sparkSession and try again', ) from e diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index e623467bf7..bbb03a26d9 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -1,12 +1,14 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import sys import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( + InsufficientPermissionsError, download, fetch_DT, format_tablename, @@ -17,6 +19,39 @@ class TestConvertDeltaToJsonl(unittest.TestCase): + def test_run_query_dbconnect_insufficient_permissions(self): + error_message = ( + '[INSUFFICIENT_PERMISSIONS] Insufficient privileges: User does not have USE SCHEMA ' + "on Schema 'main.oogabooga'. SQLSTATE: 42501" + ) + + class MockAnalysisException(Exception): + + def __init__(self, message: str): + self.message = message + + with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}): + sys.modules[ + 'pyspark.errors' + ].AnalysisException = MockAnalysisException # pyright: ignore + + mock_spark = MagicMock() + mock_spark.sql.side_effect = MockAnalysisException(error_message) + + with self.assertRaises(InsufficientPermissionsError) as context: + run_query( + 'SELECT * FROM table', + method='dbconnect', + cursor=None, + spark=mock_spark, + ) + + self.assertIn( + 'using the schema main.oogabooga', + str(context.exception), + ) + mock_spark.sql.assert_called_once_with('SELECT * FROM table') + @patch( 'databricks.sql.connect', ) From e6b8d142c3c8133f21b9e1d7c05927201976b2e8 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Wed, 25 Sep 2024 15:47:48 -0700 Subject: [PATCH 11/15] Validate Cluster Access Mode (#1551) Co-authored-by: v-chen_data Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- .../data_prep/convert_delta_to_json.py | 12 +++++++++++ llmfoundry/utils/exceptions.py | 13 ++++++++++++ .../data_prep/test_convert_delta_to_json.py | 20 +++++++++++++++---- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index d676fc2165..fbbc5f2cd9 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -20,6 +20,7 @@ from llmfoundry.utils.exceptions import ( ClusterDoesNotExistError, + ClusterInvalidAccessMode, FailedToConnectToDatabricksError, FailedToCreateSQLConnectionError, InsufficientPermissionsError, @@ -568,6 +569,17 @@ def validate_and_get_cluster_info( if res is None: raise ClusterDoesNotExistError(cluster_id) + data_security_mode = str( + res.data_security_mode, + ).upper()[len('DATASECURITYMODE.'):] + + # NONE stands for No Isolation Shared + if data_security_mode == 'NONE': + raise ClusterInvalidAccessMode( + cluster_id=cluster_id, + access_mode=data_security_mode, + ) + assert res.spark_version is not None stripped_runtime = re.sub( r'[a-zA-Z]', diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 900355dff5..265b9bbe8f 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -318,6 +318,19 @@ def __init__(self, cluster_id: str) -> None: super().__init__(message, cluster_id=cluster_id) +class ClusterInvalidAccessMode(NetworkError): + """Error thrown when the cluster does not exist.""" + + def __init__(self, cluster_id: str, access_mode: str) -> None: + message = f'Cluster with id {cluster_id} has access mode {access_mode}. ' + \ + 'please make sure the cluster used has access mode Shared or Single User!' + super().__init__( + message, + cluster_id=cluster_id, + access_mode=access_mode, + ) + + class FailedToCreateSQLConnectionError( NetworkError, ): diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index bbb03a26d9..b1a9f1e878 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -264,7 +264,10 @@ def test_dbconnect_called( DATABRICKS_TOKEN = 'token' use_serverless = False - mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') + mock_cluster_response = Namespace( + spark_version='14.1.0-scala2.12', + data_security_mode='SINGLE_USER', + ) mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response mock_remote = MagicMock() @@ -321,7 +324,10 @@ def test_sqlconnect_called_dbr13( DATABRICKS_TOKEN = 'token' use_serverless = False - mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') + mock_cluster_response = Namespace( + spark_version='13.0.0-scala2.12', + data_security_mode='SINGLE_USER', + ) mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT( @@ -373,7 +379,10 @@ def test_sqlconnect_called_dbr14( DATABRICKS_TOKEN = 'token' use_serverless = False - mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_cluster_response = Namespace( + spark_version='14.2.0-scala2.12', + data_security_mode='SINGLE_USER', + ) mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT( @@ -425,7 +434,10 @@ def test_sqlconnect_called_https( DATABRICKS_TOKEN = 'token' use_serverless = False - mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_cluster_response = Namespace( + spark_version='14.2.0-scala2.12', + data_security_mode='SINGLE_USER', + ) mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT( From dc58bb7eb95e52874774e1d5a7669a1a5f194429 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:40:13 -0700 Subject: [PATCH 12/15] Update mcli yamls (#1552) --- mcli/mcli-1b-eval.yaml | 4 ++-- mcli/mcli-1b-max-seq-len-8k.yaml | 4 ++-- mcli/mcli-1b.yaml | 4 ++-- mcli/mcli-benchmark-mpt.yaml | 4 ++-- mcli/mcli-convert-composer-to-hf.yaml | 4 ++-- mcli/mcli-hf-eval.yaml | 4 ++-- mcli/mcli-hf-generate.yaml | 4 ++-- mcli/mcli-llama2-finetune.yaml | 4 ++-- mcli/mcli-openai-eval.yaml | 4 ++-- mcli/mcli-pretokenize-oci-upload.yaml | 4 ++-- 10 files changed, 20 insertions(+), 20 deletions(-) diff --git a/mcli/mcli-1b-eval.yaml b/mcli/mcli-1b-eval.yaml index 4fcf8b3cb9..bd6a7b538a 100644 --- a/mcli/mcli-1b-eval.yaml +++ b/mcli/mcli-1b-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -9,7 +9,7 @@ integrations: command: | cd llm-foundry/scripts/ composer eval/eval.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: mpt-1b-eval compute: diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index fb96c576e0..b437bc5f0d 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -17,7 +17,7 @@ command: | --out_root ./my-copy-c4 --splits train_small val_small \ --concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: mpt-1b-ctx-8k-gpus-8 compute: diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 26255977f4..789fc4fc02 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -21,7 +21,7 @@ command: | eval_loader.dataset.split=val_small \ max_duration=100ba \ eval_interval=0 -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: mpt-1b-gpus-8 compute: diff --git a/mcli/mcli-benchmark-mpt.yaml b/mcli/mcli-benchmark-mpt.yaml index 3995598fd3..0c023f9a83 100644 --- a/mcli/mcli-benchmark-mpt.yaml +++ b/mcli/mcli-benchmark-mpt.yaml @@ -6,12 +6,12 @@ compute: # cluster: TODO # Name of the cluster to use for this run # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] diff --git a/mcli/mcli-convert-composer-to-hf.yaml b/mcli/mcli-convert-composer-to-hf.yaml index 7b715f6792..a211e3baeb 100644 --- a/mcli/mcli-convert-composer-to-hf.yaml +++ b/mcli/mcli-convert-composer-to-hf.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: . ssh_clone: false # Should be true if using a private repo @@ -13,7 +13,7 @@ command: | --hf_output_path s3://bucket/folder/hf/ \ --output_precision bf16 \ -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: convert-composer-hf compute: diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index 27f5938d67..9bcebfbea0 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -16,7 +16,7 @@ gpu_num: 8 # gpu_type: # cluster: # replace with your cluster here! -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-hf-generate.yaml b/mcli/mcli-hf-generate.yaml index cb3040e4ee..85a0f6b0e4 100644 --- a/mcli/mcli-hf-generate.yaml +++ b/mcli/mcli-hf-generate.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -35,7 +35,7 @@ command: | "Here's a quick recipe for baking chocolate chip cookies: Start by" \ "The best 5 cities to visit in Europe are" -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: hf-generate compute: diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index 7134e6204c..210e8942b5 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo @@ -9,7 +9,7 @@ integrations: command: | cd llm-foundry/scripts composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest name: llama2-finetune compute: diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml index cd04d89f4e..987fc829a9 100644 --- a/mcli/mcli-openai-eval.yaml +++ b/mcli/mcli-openai-eval.yaml @@ -1,7 +1,7 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: .[gpu,openai] ssh_clone: false # Should be true if using a private repo @@ -16,7 +16,7 @@ gpu_num: # gpu_type: # cluster: # replace with your cluster here! -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index 5425ce9897..49fbbb08d8 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -1,5 +1,5 @@ name: c4-2k-pre-tokenized -image: mosaicml/llm-foundry:2.3.1_cu121-latest +image: mosaicml/llm-foundry:2.4.0_cu124-latest compute: gpus: 8 # Number of GPUs to use @@ -14,7 +14,7 @@ integrations: - oci-cli==3.23.2 - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.11.0 + git_branch: v0.12.0 # git_commit: # OR use your commit hash pip_install: . ssh_clone: false # Should be true if using a private repo From 3b1fc4ae5c205118901fcf1557260952fe844e2e Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:23:34 -0400 Subject: [PATCH 13/15] Use `allenai/c4` instead of `c4` dataset (#1554) Co-authored-by: Eitan Turok --- README.md | 2 +- TUTORIAL.md | 4 ++-- .../data_prep/convert_dataset_hf.py | 4 ++-- .../data_prep/convert_dataset_json.py | 2 +- mcli/mcli-1b-max-seq-len-8k.yaml | 2 +- mcli/mcli-1b.yaml | 2 +- mcli/mcli-pretokenize-oci-upload.yaml | 2 +- scripts/data_prep/README.md | 2 +- scripts/train/README.md | 6 ++--- .../train/benchmarking/submit_benchmarks.py | 2 +- .../data_prep/test_convert_dataset_hf.py | 2 +- tests/a_scripts/eval/test_eval.py | 11 +++++----- tests/a_scripts/train/test_train.py | 22 ++++++++++--------- tests/data/test_dataloader.py | 6 ++--- tests/data_utils.py | 2 +- 15 files changed, 37 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 0fabb98653..bc4eff48fd 100644 --- a/README.md +++ b/README.md @@ -223,7 +223,7 @@ cd scripts # Convert C4 dataset to StreamingDataset format python data_prep/convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root my-copy-c4 --splits train_small val_small \ --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' diff --git a/TUTORIAL.md b/TUTORIAL.md index 3be4910c4f..d1751f62e3 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -216,7 +216,7 @@ Output the processed data to `./my-adaptation-data`. Note that we use smaller su ```bash python scripts/data_prep/convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root my-adaptation-data --splits train_small val_small \ --concat_tokens 4096 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' \ --compression zstd @@ -248,7 +248,7 @@ The first step to training from scratch is to get your pretraining data prepared ```bash python scripts/data_prep/convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root my-copy-c4 --splits train_small val_small \ --concat_tokens 2048 --tokenizer gpt2 \ --eos_text '<|endoftext|>' \ diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py index 0ea94ac687..2667407110 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_hf.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_hf.py @@ -158,7 +158,7 @@ def __init__( truncated_samples=100, ) -CONSTS = {'c4': c4constants, 'the_pile': pileconstants} +CONSTS = {'allenai/c4': c4constants, 'the_pile': pileconstants} def build_hf_dataset( @@ -335,7 +335,7 @@ def convert_dataset_hf( dataset_constants = CONSTS[dataset] except KeyError: raise ValueError( - f'Constants for dataset "{dataset}" not found. Currently only "the_pile" and "c4" are supported.', + f'Constants for dataset "{dataset}" not found. Currently only "the_pile" and "allenai/c4" are supported.', ) if concat_tokens is not None and tokenizer is not None: diff --git a/llmfoundry/command_utils/data_prep/convert_dataset_json.py b/llmfoundry/command_utils/data_prep/convert_dataset_json.py index 35d7e637e6..c6f7d51c02 100644 --- a/llmfoundry/command_utils/data_prep/convert_dataset_json.py +++ b/llmfoundry/command_utils/data_prep/convert_dataset_json.py @@ -43,7 +43,7 @@ def build_hf_dataset( no_wrap (bool): if concatenating, whether to wrap text across `max_length` boundaries tokenizer (PreTrainedTokenizerBase): if mode is CONCAT_TOKENS, the tokenizer to use data_subset (str): Referred to as "name" in HuggingFace datasets.load_dataset. - Typically "all" (The Pile) or "en" (c4). + Typically "all" (The Pile) or "en" (allenai/c4). Returns: An IterableDataset. diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index b437bc5f0d..1d48cd8105 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -13,7 +13,7 @@ integrations: command: | cd llm-foundry/scripts python data_prep/convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root ./my-copy-c4 --splits train_small val_small \ --concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 789fc4fc02..71566d4c46 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -13,7 +13,7 @@ integrations: command: | cd llm-foundry/scripts python data_prep/convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root ./my-copy-c4 --splits train_small val_small \ --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' composer train/train.py train/yamls/pretrain/mpt-1b.yaml \ diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index 49fbbb08d8..a3e8c40b88 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -24,7 +24,7 @@ command: | # Run the dataset conversion python convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root ./my-copy-c4 \ --splits val_small val train_small train \ --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' diff --git a/scripts/data_prep/README.md b/scripts/data_prep/README.md index 3601cc865f..b72caeebc4 100644 --- a/scripts/data_prep/README.md +++ b/scripts/data_prep/README.md @@ -14,7 +14,7 @@ Currently supports `c4` and `The Pile`. ```bash # Convert C4 dataset to StreamingDataset format python convert_dataset_hf.py \ - --dataset c4 --data_subset en \ + --dataset allenai/c4 --data_subset en \ --out_root my-copy-c4 --splits train_small val_small \ --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' \ --compression zstd diff --git a/scripts/train/README.md b/scripts/train/README.md index 6730cb793b..247814d782 100644 --- a/scripts/train/README.md +++ b/scripts/train/README.md @@ -27,7 +27,7 @@ If you haven't already, make sure to [install the requirements](../../README.md# To run pretraining, you'll need to make yourself a copy of a pretraining dataset and format it for efficient streaming. Check out the [`llm-foundry/data_prep`](../data_prep) folder for detailed instructions on how to convert your dataset to the MosaicML [StreamingDataset](https://github.com/mosaicml/streaming) format. -As a quickstart, we elaborate on how to prepare the [C4 (Colossal, Cleaned, Common Crawl)](https://huggingface.co/datasets/c4) dataset here. +As a quickstart, we elaborate on how to prepare the [C4 (Colossal, Cleaned, Common Crawl)](https://huggingface.co/datasets/allenai/c4) dataset here. We first convert the dataset from its native format (a collection of zipped JSONs) to MosaicML's StreamingDataset format, which is a collection of binary `.mds` files. @@ -44,13 +44,13 @@ This will take 20-60 seconds depending on your internet bandwidth. You should see two folders once completed: `./my-copy-c4/train_small` and `./my-copy-c4/val_small` that are ~1.0GB total. Note that we are using the `--concat_tokens` option to pre tokenize our samples to be of the max sequence length without padding ```bash -python ../data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' +python ../data_prep/convert_dataset_hf.py --dataset allenai/c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' ``` Alternatively, you can download the full `train` and `val` splits if you really want to train the model (i.e. not just profile the model). This will take 1-to-many hours depending on bandwidth, number of CPUs, etc. The final folder `./my-copy-c4/train` will be ~800GB so make sure you have space! ```bash -python ../data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train val --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' +python ../data_prep/convert_dataset_hf.py --dataset allenai/c4 --data_subset en --out_root ./my-copy-c4 --splits train val --concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' ``` For any of the above commands, you can also choose to compress the `.mds` files. diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index fd7be1fc6d..27f5c26c7d 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -479,7 +479,7 @@ def run_config( if args.data_remote is None: command += f""" cd llm-foundry/scripts - python data_prep/convert_dataset_hf.py --dataset c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --eos_text '<|endoftext|>' + python data_prep/convert_dataset_hf.py --dataset allenai/c4 --data_subset en --out_root ./my-copy-c4 --splits train_small val_small --concat_tokens {max_seq_len} --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml """ else: diff --git a/tests/a_scripts/data_prep/test_convert_dataset_hf.py b/tests/a_scripts/data_prep/test_convert_dataset_hf.py index e09c54ca70..da1e101ae7 100644 --- a/tests/a_scripts/data_prep/test_convert_dataset_hf.py +++ b/tests/a_scripts/data_prep/test_convert_dataset_hf.py @@ -11,7 +11,7 @@ def test_download_script_from_api(tmp_path: Path): # test calling it directly path = os.path.join(tmp_path, 'my-copy-c4-1') convert_dataset_hf( - dataset='c4', + dataset='allenai/c4', data_subset='en', splits=['val_xsmall'], out_root=path, diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index fc0dc8a882..f1b76913d1 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -121,7 +121,7 @@ def test_loader_eval( # Set up multiple eval dataloaders first_eval_loader = test_cfg.eval_loader - first_eval_loader.label = 'c4' + first_eval_loader.label = 'allenai/c4' # Create second eval dataloader using the arxiv dataset. second_eval_loader = copy.deepcopy(first_eval_loader) second_eval_loader.label = 'arxiv' @@ -157,16 +157,17 @@ def test_loader_eval( print(inmemorylogger.data.keys()) # Checks for first eval dataloader - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, ) diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 9af96f9868..b1bca9ebd0 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -134,7 +134,7 @@ def test_train_multi_eval(tmp_path: pathlib.Path): test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') # Set up multiple eval dataloaders first_eval_loader = test_cfg.eval_loader - first_eval_loader.label = 'c4' + first_eval_loader.label = 'allenai/c4' # Create second eval dataloader using the arxiv dataset. second_eval_loader = copy.deepcopy(first_eval_loader) second_eval_loader.label = 'arxiv' @@ -154,16 +154,17 @@ def test_train_multi_eval(tmp_path: pathlib.Path): assert isinstance(inmemorylogger, InMemoryLogger) # Checks for first eval dataloader - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, ) @@ -212,7 +213,7 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') first_eval_loader = test_cfg.eval_loader - first_eval_loader.label = 'c4' + first_eval_loader.label = 'allenai/c4' test_cfg.eval_loader = om.create([first_eval_loader]) test_cfg.eval_subset_num_batches = 1 # -1 to evaluate on all batches test_cfg.max_duration = '1ba' @@ -226,15 +227,16 @@ def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): 0] # pyright: ignore [reportGeneralTypeIssues] assert isinstance(inmemorylogger, InMemoryLogger) - assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert 'metrics/eval/allenai/c4/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'], list, ) assert len( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], ) > 0 assert isinstance( - inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], + inmemorylogger.data['metrics/eval/allenai/c4/LanguageCrossEntropy'][-1], tuple, ) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index d215d93542..7239bfe958 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -204,7 +204,7 @@ def test_correct_padding( shutil.rmtree(path, ignore_errors=True) if pretokenize: convert_dataset_hf( - dataset='c4', + dataset='allenai/c4', data_subset='en', splits=[split], out_root=path, @@ -219,7 +219,7 @@ def test_correct_padding( ) else: convert_dataset_hf( - dataset='c4', + dataset='allenai/c4', data_subset='en', splits=[split], out_root=path, @@ -233,7 +233,7 @@ def test_correct_padding( num_workers=None, ) if not os.path.isdir(path): - raise RuntimeError(f'c4 dataset at {path} not set up as expected') + raise RuntimeError(f'allenai/c4 dataset at {path} not set up as expected') test_cfg = get_config( conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml', diff --git a/tests/data_utils.py b/tests/data_utils.py index 117310b0cf..1f6c26b72e 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -231,7 +231,7 @@ def create_c4_dataset_xxsmall(path: Path) -> str: # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 convert_dataset_hf( - dataset='c4', + dataset='allenai/c4', data_subset='en', splits=[downloaded_split], out_root=c4_dir, From ee456002a1dd86f3d9102ac5ade9f7436be51d82 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 27 Sep 2024 10:39:39 -0400 Subject: [PATCH 14/15] Tensor Parallelism (#1521) Co-authored-by: Eitan Turok Co-authored-by: Mihir Patel --- llmfoundry/__init__.py | 2 + llmfoundry/command_utils/train.py | 32 +++++-- llmfoundry/registry.py | 22 +++++ llmfoundry/tp/__init__.py | 11 +++ llmfoundry/tp/ffn_tp_strategy.py | 56 +++++++++++++ llmfoundry/utils/builders.py | 29 +++++-- llmfoundry/utils/config_utils.py | 14 +++- tests/test_registry.py | 1 + tests/tp/__init__.py | 2 + tests/tp/test_tp_strategies.py | 133 ++++++++++++++++++++++++++++++ 10 files changed, 289 insertions(+), 13 deletions(-) create mode 100644 llmfoundry/tp/__init__.py create mode 100644 llmfoundry/tp/ffn_tp_strategy.py create mode 100644 tests/tp/__init__.py create mode 100644 tests/tp/test_tp_strategies.py diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index b851aaa559..07e8f35747 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -48,6 +48,7 @@ models, optim, tokenizers, + tp, utils, ) from llmfoundry._version import __version__ @@ -87,5 +88,6 @@ 'models', 'optim', 'tokenizers', + 'tp', 'utils', ] diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 14b7980d57..29878714f6 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -5,6 +5,7 @@ import os import time import warnings +from copy import deepcopy from typing import Any, Optional, Union import torch @@ -43,6 +44,7 @@ build_save_planner, build_scheduler, build_tokenizer, + build_tp_strategies, ) from llmfoundry.utils.config_utils import ( TRAIN_CONFIG_KEYS, @@ -329,16 +331,27 @@ def train(cfg: DictConfig) -> Trainer: changing autoresume default to True...', ) - # Warn if fsdp is enabled but user only has 1 GPU - if dist.get_world_size() == 1 and fsdp_config is not None: + # Optional tp config + tp_config: Optional[dict[str, Any]] = train_cfg.tp_config + + # Warn if FSDP or TP is enabled but user only has 1 GPU + if dist.get_world_size( + ) == 1 and (fsdp_config is not None or tp_config is not None): + parallelism = '' + if fsdp_config is not None: + parallelism += 'FSDP' + if tp_config is not None: + parallelism += '+TP' if fsdp_config is not None else 'TP' warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.', + f'{parallelism} is not applicable for single-GPU training. Reverting to DDP.', ) fsdp_config = None + tp_config = None # Initialize context - init_context = process_init_device(model_config, fsdp_config) + init_context = process_init_device(model_config, fsdp_config, tp_config) logged_cfg.update({'fsdp_config': fsdp_config}, merge=True) + logged_cfg.update({'tp_config': deepcopy(tp_config)}, merge=True) # Build tokenizer log.info('Building tokenizer...') @@ -502,6 +515,15 @@ def train(cfg: DictConfig) -> Trainer: _log_num_params(model, logged_cfg) + # TP config + if tp_config is not None: + strategy = tp_config.pop('strategy', None) + assert isinstance(strategy, str), '`strategy` must be in `tp_config`.' + tp_config['layer_plan'] = build_tp_strategies(strategy, model) + + # Parallelism config + parallelism_config = {'fsdp': fsdp_config, 'tp': tp_config} + # Optimizer optimizer_name: str = train_cfg.optimizer.pop('name') optimizer_cfg = train_cfg.optimizer @@ -546,7 +568,7 @@ def train(cfg: DictConfig) -> Trainer: precision=train_cfg.precision, algorithms=algorithms, device_train_microbatch_size=train_cfg.device_train_microbatch_size, - parallelism_config={'fsdp': fsdp_config}, + parallelism_config=parallelism_config, save_folder=train_cfg.save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index cb2455a760..850c4f3bbd 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -7,6 +7,7 @@ from composer.models import ComposerModel from composer.optim import ComposerScheduler from torch.distributed.checkpoint import LoadPlanner, SavePlanner +from torch.distributed.tensor.parallel.style import ParallelStyle from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader from torch.utils.data import Dataset @@ -389,6 +390,26 @@ description=_save_planners_description, ) +_tp_strategies_description = ( + """The tp_strategies registry is used to register strategies for tensor parallelism. + + Args: + model (ComposerModel): The model. + + Returns: + layer_plan (Dict[str, ParallelStyle]): The plan used to parallelize the model. + model (ComposerModel): The model. + """ +) + +tp_strategies = create_registry( + 'llmfoundry', + 'tp_strategies', + generic_type=Callable[[ComposerModel], dict[str, ParallelStyle]], + entry_points=True, + description=_tp_strategies_description, +) + __all__ = [ 'loggers', 'callbacks', @@ -416,4 +437,5 @@ 'config_transforms', 'load_planners', 'save_planners', + 'tp_strategies', ] diff --git a/llmfoundry/tp/__init__.py b/llmfoundry/tp/__init__.py new file mode 100644 index 0000000000..323ae23727 --- /dev/null +++ b/llmfoundry/tp/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.registry import tp_strategies +from llmfoundry.tp.ffn_tp_strategy import ffn_tp_strategy + +tp_strategies.register('ffn', func=ffn_tp_strategy) + +__all__ = [ + 'ffn_tp_strategy', +] diff --git a/llmfoundry/tp/ffn_tp_strategy.py b/llmfoundry/tp/ffn_tp_strategy.py new file mode 100644 index 0000000000..1de92ef6ae --- /dev/null +++ b/llmfoundry/tp/ffn_tp_strategy.py @@ -0,0 +1,56 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from composer.models import ComposerModel +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) +from torch.distributed.tensor.parallel.style import ParallelStyle + + +def ffn_tp_strategy(model: ComposerModel) -> dict[str, ParallelStyle]: + TP_LAYERS = {'ffn', 'ffn.up_proj', 'ffn.down_proj'} + + # Validate that all TP_LAYERS are in model + tp_layers_in_model = { + layer for layer in TP_LAYERS for name, _ in model.named_modules() + if layer in name + } + if tp_layers_in_model != TP_LAYERS: + raise RuntimeError( + f'The FFN tensor parallelism strategy requires `model` to have layers {TP_LAYERS}. But `model` is missing layers {TP_LAYERS - tp_layers_in_model}.', + ) + + # Generate layer plan + layer_plan: dict[str, ParallelStyle] = {} + for name, _ in model.named_modules(): + # Before the ffn layer starts, distribute the input data for proper TP use + # Inputs are currently sharded across the batch dimension (dim 0) as is done in standard DDP + # Inputs will be replicated across hidden dimension (dim 1) via allgather + if name.split('.')[-1] == 'ffn': + layer_plan[name] = PrepareModuleInput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + use_local_output=True, + ) + # Shard the ffn.up_proj weight matrix across its columns + # Inputs are already replicated across each TP group + # Outputs will be sharded along the hidden dimension (dim 1) via allgather + elif name.split('.')[-2:] == ['ffn', 'up_proj']: + layer_plan[name] = ColwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(-1), + ) + # Shard the ffn.down_proj weight matrix across its rows + # Inputs are sharded along the hidden dimension (dim 1) + # Outputs will be sharded along batch dimension (dim 0) via allreduce + elif name.split('.')[-2:] == ['ffn', 'down_proj']: + layer_plan[name] = RowwiseParallel( + input_layouts=Shard(-1), + output_layouts=Shard(0), + ) + + return layer_plan diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f2d5cfc0f7..687b21b46d 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -7,14 +7,9 @@ import logging import os import re +import warnings from collections import OrderedDict -from typing import ( - Any, - ContextManager, - Iterable, - Optional, - Union, -) +from typing import Any, ContextManager, Iterable, Optional, Union import torch from composer.core import Algorithm, Callback, Evaluator @@ -25,6 +20,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.distributed.checkpoint import LoadPlanner, SavePlanner +from torch.distributed.tensor.parallel.style import ParallelStyle from torch.optim.optimizer import Optimizer from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -37,6 +33,7 @@ ) from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry +from llmfoundry.utils.warnings import experimental_function log = logging.getLogger(__name__) @@ -52,6 +49,7 @@ 'build_tokenizer', 'build_composer_model', 'build_metric', + 'build_tp_strategies', ] @@ -701,3 +699,20 @@ def _validate_cfg(icl_cfg: dict[str, Any]): ) return evaluators, logger_keys + + +@experimental_function('Tensor Parallelism') +def build_tp_strategies( + name: str, + model: ComposerModel, +) -> dict[str, ParallelStyle]: + + warnings.warn( + 'Checkpointing is not currently supported for tensor parallelism due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095#issuecomment-2345018244', + ) + return construct_from_registry( + name=name, + registry=registry.tp_strategies, + partial_function=False, + kwargs={'model': model}, + ) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index ba5c5941b8..c22495993c 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -120,6 +120,7 @@ class TrainConfig: # Distributed training parameters dist_timeout: Union[int, float] = 600.0 fsdp_config: Optional[dict[str, Any]] = None + tp_config: Optional[dict[str, Any]] = None # Evaluation parameters eval_interval: Union[int, str] = 1 @@ -501,7 +502,11 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]: return cfg -def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]): +def process_init_device( + model_cfg: dict[str, Any], + fsdp_config: Optional[dict] = None, + tp_config: Optional[dict] = None, +): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors # when multiple GPUs are available. @@ -533,6 +538,13 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]): # Set defaults for mixed initialization fsdp_config.setdefault('load_monolith_rank0_only', True) + # Check we are not using tensor parallelism with MoEs + if tp_config is not None and 'ffn_config' in model_cfg and model_cfg[ + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: + raise ValueError( + 'Tensor Parallelism is not currently supported for MoE models.', + ) + # Set ffn_config.device_mesh using fsdp_config if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[ 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: diff --git a/tests/test_registry.py b/tests/test_registry.py index 5108a7d46c..90ef3bfaac 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -47,6 +47,7 @@ def test_expected_registries_exist(): 'config_transforms', 'load_planners', 'save_planners', + 'tp_strategies', } assert existing_registries == expected_registry_names diff --git a/tests/tp/__init__.py b/tests/tp/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/tests/tp/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/tp/test_tp_strategies.py b/tests/tp/test_tp_strategies.py new file mode 100644 index 0000000000..fd2fa384ce --- /dev/null +++ b/tests/tp/test_tp_strategies.py @@ -0,0 +1,133 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest +from omegaconf import OmegaConf as om +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, +) + +from llmfoundry.command_utils.train import train +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM +from llmfoundry.utils.builders import build_tp_strategies +from llmfoundry.utils.config_utils import process_init_device +from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg + + +@pytest.mark.gpu +@pytest.mark.filterwarnings( + 'ignore:tp_strategies is experimental and may change with future versions.', +) +def test_ffn_tp_strategy(): + """Test the FFN tensor parallelism strategy is correct.""" + # Create layer plan from fnn tp_strategy + tp_config = { + 'strategy': 'ffn', + } + + model_cfg = { + 'name': 'mpt_causal_lm', + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 3, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + } + model = ComposerMPTCausalLM(**model_cfg) + layer_plan = build_tp_strategies(tp_config['strategy'], model) + + # Expected layer plan + _expected_layer_plan = { + 'ffn': + PrepareModuleInput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + use_local_output=True, + ), + 'ffn.down_proj': + RowwiseParallel( + input_layouts=Shard(-1), + output_layouts=Shard(0), + ), + 'ffn.up_proj': + ColwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(-1), + ), + } + expected_layer_plan = { + f'model.transformer.blocks.{layer_idx}.{name}': layer_plan + for name, layer_plan in _expected_layer_plan.items() + for layer_idx in range(model_cfg['n_layers']) + } + + # Compare expected and actual layer plans + for (n1, lp1), (n2, lp2) in zip( + sorted(expected_layer_plan.items()), + sorted(layer_plan.items()), + ): + assert n1 == n2 + assert type(lp1) == type(lp2) + if isinstance( + lp1, + PrepareModuleInput, + ) and isinstance(lp2, PrepareModuleInput): + assert lp1.input_layouts == lp2.input_layouts + assert lp1.desired_input_layouts == lp2.desired_input_layouts + assert lp1.use_local_output == lp2.use_local_output + elif ( + isinstance(lp1, ColwiseParallel) and + isinstance(lp2, ColwiseParallel) + ) or ( + isinstance(lp1, RowwiseParallel) and + isinstance(lp2, RowwiseParallel) + ): + assert lp1.input_layouts == lp2.input_layouts + assert lp1.output_layouts == lp2.output_layouts + assert lp1.use_local_output == lp2.use_local_output + else: + raise ValueError(f'Layer plan of wrong type: {type(layer_plan)}') + + +@pytest.mark.gpu +def test_no_tp_with_one_gpu(): + """Test that when we have one GPU, we use DDP and not FSDP-TP.""" + with TemporaryDirectory() as tmp_path: + # Make `train_cfg`` with a tensor parallelism strategy + dataset_name = create_c4_dataset_xxsmall(Path(tmp_path)) + train_cfg = gpt_tiny_cfg(dataset_name, 'gpu') + train_cfg.tp_config = {'strategy': 'ffn'} + + # Expect a warning + with pytest.warns( + UserWarning, + match= + r'FSDP\+TP is not applicable for single-GPU training. Reverting to DDP.', + ): + train(train_cfg) + + +@pytest.mark.gpu # use gpu because `megablocks` only installed with `gpu` dependencies +def test_no_tp_with_moes(): + """Test that tensor parallelism is not compatible with MoEs.""" + # Make `cfg` for MoE model, fsdp, and tp + train_cfg_path: str = 'scripts/train/yamls/pretrain/testing-moe.yaml' + with open(train_cfg_path, 'r', encoding='utf-8') as f: + train_cfg = om.load(f) + model_cfg = train_cfg.model + fsdp_cfg = train_cfg.fsdp_config + tp_cfg = {'strategy': 'ffn'} + + # Expect an error + with pytest.raises( + ValueError, + match='Tensor Parallelism is not currently supported for MoE models.', + ): + process_init_device(model_cfg, fsdp_cfg, tp_cfg) From 107d246a4c9c04f0a906f8f0fafcca1297d9e68e Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Fri, 27 Sep 2024 13:12:00 -0700 Subject: [PATCH 15/15] Insufficient Permissions Error when trying to access table (#1555) Co-authored-by: v-chen_data --- .../data_prep/convert_delta_to_json.py | 127 +++++++----------- llmfoundry/utils/exceptions.py | 13 +- .../data_prep/test_convert_delta_to_json.py | 23 ++-- tests/utils/test_exceptions.py | 39 ++++-- 4 files changed, 103 insertions(+), 99 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index fbbc5f2cd9..44e8651cdf 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -234,27 +234,7 @@ def run_query( elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') - - try: - df = spark.sql(query) - except Exception as e: - from pyspark.errors import AnalysisException - if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - match = re.search( - r"Schema\s+'([^']+)'", - e.message, # pyright: ignore - ) - if match: - schema_name = match.group(1) - action = f'using the schema {schema_name}' - else: - action = 'using the schema' - raise InsufficientPermissionsError(action=action,) from e - raise RuntimeError( - f'Error in querying into schema. Restart sparkSession and try again', - ) from e - + df = spark.sql(query) if collect: return df.collect() return df @@ -469,71 +449,66 @@ def fetch( """ cursor = dbsql.cursor() if dbsql is not None else None try: - nrows = get_total_rows( - tablename, - method, - cursor, - sparkSession, - ) - except Exception as e: - from pyspark.errors import AnalysisException - if isinstance(e, AnalysisException): - if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore - raise InsufficientPermissionsError( - action=f'reading from {tablename}', - ) from e - if isinstance(e, InsufficientPermissionsError): - raise e - raise RuntimeError( - f'Error in get rows from {tablename}. Restart sparkSession and try again', - ) from e + # Get total rows + nrows = get_total_rows(tablename, method, cursor, sparkSession) - try: + # Get columns info columns, order_by, columns_str = get_columns_info( tablename, method, cursor, sparkSession, ) + + if method == 'dbconnect' and sparkSession is not None: + log.info(f'{processes=}') + df = sparkSession.table(tablename) + + # Running the query and collecting the data as arrow or json. + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + log.info(f'len(signed) = {len(signed)}') + + args = get_args(signed, json_output_folder, columns) + + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. + sparkSession.stop() + + with ProcessPoolExecutor(max_workers=processes) as executor: + list(executor.map(download_starargs, args)) + + elif method == 'dbsql' and cursor is not None: + for start in range(0, nrows, batch_size): + log.warning(f'batch {start}') + end = min(start + batch_size, nrows) + fetch_data( + method, + cursor, + sparkSession, + start, + end, + order_by, + tablename, + columns_str, + json_output_folder, + ) + except Exception as e: - raise RuntimeError( - f'Error in get columns from {tablename}. Restart sparkSession and try again', - ) from e + from databricks.sql.exc import ServerOperationError + from pyspark.errors import AnalysisException - if method == 'dbconnect' and sparkSession is not None: - log.info(f'{processes=}') - df = sparkSession.table(tablename) - - # Running the query and collecting the data as arrow or json. - signed, _, _ = df.collect_cf('arrow') # pyright: ignore - log.info(f'len(signed) = {len(signed)}') - - args = get_args(signed, json_output_folder, columns) - - # Stopping the SparkSession to avoid spilling connection state into the subprocesses. - sparkSession.stop() - - with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_starargs, args)) - - elif method == 'dbsql' and cursor is not None: - for start in range(0, nrows, batch_size): - log.warning(f'batch {start}') - end = min(start + batch_size, nrows) - fetch_data( - method, - cursor, - sparkSession, - start, - end, - order_by, - tablename, - columns_str, - json_output_folder, - ) + if isinstance(e, (AnalysisException, ServerOperationError)): + if 'INSUFFICIENT_PERMISSIONS' in str(e): + raise InsufficientPermissionsError(str(e)) from e + + if isinstance(e, InsufficientPermissionsError): + raise + + # For any other exception, raise a general error + raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e - if cursor is not None: - cursor.close() + finally: + if cursor is not None: + cursor.close() def validate_and_get_cluster_info( diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 265b9bbe8f..242ac4f32c 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -456,6 +456,13 @@ def __init__( class InsufficientPermissionsError(UserError): """Error thrown when the user does not have sufficient permissions.""" - def __init__(self, action: str) -> None: - message = f'Insufficient permissions when {action}. Please check your permissions.' - super().__init__(message, action=action) + def __init__(self, message: str) -> None: + self.message = message + super().__init__(message) + + def __reduce__(self): + # Return a tuple of class, a tuple of arguments, and optionally state + return (InsufficientPermissionsError, (self.message,)) + + def __str__(self): + return self.message diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index b1a9f1e878..981f5c1ed6 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -10,6 +10,7 @@ from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( InsufficientPermissionsError, download, + fetch, fetch_DT, format_tablename, iterative_combine_jsons, @@ -30,27 +31,33 @@ class MockAnalysisException(Exception): def __init__(self, message: str): self.message = message + def __str__(self): + return self.message + with patch.dict('sys.modules', {'pyspark.errors': MagicMock()}): sys.modules[ 'pyspark.errors' - ].AnalysisException = MockAnalysisException # pyright: ignore + ].AnalysisException = MockAnalysisException # type: ignore mock_spark = MagicMock() mock_spark.sql.side_effect = MockAnalysisException(error_message) with self.assertRaises(InsufficientPermissionsError) as context: - run_query( - 'SELECT * FROM table', + fetch( method='dbconnect', - cursor=None, - spark=mock_spark, + tablename='main.oogabooga', + json_output_folder='/fake/path', + batch_size=1, + processes=1, + sparkSession=mock_spark, + dbsql=None, ) - self.assertIn( - 'using the schema main.oogabooga', + self.assertEqual( str(context.exception), + error_message, ) - mock_spark.sql.assert_called_once_with('SELECT * FROM table') + mock_spark.sql.assert_called() @patch( 'databricks.sql.connect', diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index 8bfc7287ab..564dfa2f14 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -4,7 +4,7 @@ import contextlib import inspect import pickle -from typing import Any, Optional +from typing import Any, Optional, get_type_hints import pytest @@ -14,16 +14,30 @@ def create_exception_object( exception_class: type[foundry_exceptions.BaseContextualError], ): - # get required arg types of exception class by inspecting its __init__ method - if hasattr(inspect, 'get_annotations'): - required_args = inspect.get_annotations( # type: ignore - exception_class.__init__, - ) # type: ignore - else: - required_args = exception_class.__init__.__annotations__ # python 3.9 and below - - # create a dictionary of required args with default values + def get_init_annotations(cls: type): + try: + return get_type_hints(cls.__init__) + except (AttributeError, TypeError): + # Handle cases where __init__ does not exist or has no annotations + return {} + + # First, try to get annotations from the class itself + required_args = get_init_annotations(exception_class) + + # If the annotations are empty, look at parent classes + if not required_args: + for parent in exception_class.__bases__: + if parent == object: + break + parent_args = get_init_annotations(parent) + if parent_args: + required_args = parent_args + break + + # Remove self, return, and kwargs + required_args.pop('self', None) + required_args.pop('return', None) required_args.pop('kwargs', None) def get_default_value(arg_type: Optional[type] = None): @@ -51,8 +65,6 @@ def get_default_value(arg_type: Optional[type] = None): return [{'key': 'value'}] raise ValueError(f'Unsupported arg type: {arg_type}') - required_args.pop('self', None) - required_args.pop('return', None) kwargs = { arg: get_default_value(arg_type) for arg, arg_type in required_args.items() @@ -80,6 +92,7 @@ def filter_exceptions(possible_exceptions: list[str]): def test_exception_serialization( exception_class: type[foundry_exceptions.BaseContextualError], ): + print(f'Testing serialization for {exception_class.__name__}') excluded_base_classes = [ foundry_exceptions.InternalError, foundry_exceptions.UserError, @@ -88,6 +101,7 @@ def test_exception_serialization( ] exception = create_exception_object(exception_class) + print(f'Created exception object: {exception}') expect_reduce_error = exception.__class__ in excluded_base_classes error_context = pytest.raises( @@ -95,6 +109,7 @@ def test_exception_serialization( ) if expect_reduce_error else contextlib.nullcontext() exc_str = str(exception) + print(f'Exception string: {exc_str}') with error_context: pkl = pickle.dumps(exception) unpickled_exc = pickle.loads(pkl)