diff --git a/README.md b/README.md index 7ce58c772f..ef2a754658 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ Tutorial videos from the community: Something missing? Contribute with a PR! # Latest News +* [Blog: Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) * [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators) * [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus) * [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250) @@ -305,7 +306,7 @@ dependencies = [ "llm-foundry", ] -[project.entry-points."llm_foundry.loggers"] +[project.entry-points."llmfoundry_loggers"] my_logger = "foundry_registry.loggers:MyLogger" ``` diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 9eb23f1030..b7d80bd5f8 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -7,9 +7,12 @@ import math import os import re +import shutil import tempfile +import time +from multiprocessing.context import SpawnProcess from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import torch from composer.core import Callback, Event, State, Time, TimeUnit @@ -72,6 +75,31 @@ def _maybe_get_license_filename( return None +def _register_model_with_run_id_multiprocess( + mlflow_logger: MLFlowLogger, + composer_logging_level: int, + model_uri: str, + name: str, + await_creation_for: int, +): + """Call MLFlowLogger.register_model_with_run_id. + + Used mainly to register from a child process. + """ + # Setup logging for child process. This ensures that any logs from composer are surfaced. + if composer_logging_level > 0: + # If logging_level is 0, then the composer logger was unset. + logging.basicConfig( + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + ) + logging.getLogger('composer').setLevel(composer_logging_level) + + # Register model. + mlflow_logger.register_model_with_run_id( + model_uri=model_uri, name=name, await_creation_for=await_creation_for) + + class HuggingFaceCheckpointer(Callback): """Save a huggingface formatted checkpoint during training. @@ -171,6 +199,10 @@ def __init__( self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] + self.child_processes: List[SpawnProcess] = [] + # Temporary save directory used by child_processes. + self.temp_save_dir = None + def run_event(self, event: Event, state: State, logger: Logger) -> None: # The interval scheduler handles only returning True for the appropriate events if state.get_elapsed_duration() is not None and self.check_interval( @@ -201,7 +233,22 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: import mlflow mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( - '5GB') + '1GB') + elif event == Event.FIT_END: + # Wait for all child processes spawned by the callback to finish. + timeout = 3600 + wait_start = time.time() + while not self._all_child_processes_done(): + wait_time = time.time() - wait_start + if wait_time > timeout: + raise TimeoutError( + f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.' 
+ ) + time.sleep(2) + + # Clean up temporary save directory; all processes are done with it. + if self.temp_save_dir is not None: + shutil.rmtree(self.temp_save_dir) def _is_last_batch(self, state: State): elapsed_duration = state.get_elapsed_duration() @@ -218,6 +265,12 @@ def _is_last_batch(self, state: State): return False + def _all_child_processes_done(self) -> bool: + not_done = any(process.is_alive() for process in self.child_processes) + x = torch.tensor(1 if not_done else 0).to(device='cuda') + dist.all_reduce(x, reduce_operation='MAX') + return x.item() == 0 + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -235,158 +288,175 @@ def _save_checkpoint(self, state: State, logger: Logger): Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp) - dir_context_mgr = tempfile.TemporaryDirectory( - ) if self.remote_ud is not None else contextlib.nullcontext( - enter_result=save_dir) - - with dir_context_mgr as temp_save_dir: - assert isinstance(temp_save_dir, - str) # pyright doesn't know about enter_result - - log.debug('Gathering state dict') - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - if state.is_model_ddp: - composer_model = state.model.module - original_model: PreTrainedModel = state.model.module.model - state_dict_model = state.model.module.model - original_tokenizer = state.model.module.tokenizer - elif isinstance(state.model.model, FSDP): - composer_model = state.model - original_model: PreTrainedModel = state.model.model.module - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer + + # Use a temporary directory if save_dir is remote. + use_temp_dir = self.remote_ud is not None + temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir + + log.debug('Gathering state dict') + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + if state.is_model_ddp: + composer_model = state.model.module + original_model: PreTrainedModel = state.model.module.model + state_dict_model = state.model.module.model + original_tokenizer = state.model.module.tokenizer + elif isinstance(state.model.model, FSDP): + composer_model = state.model + original_model: PreTrainedModel = state.model.model.module + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + else: + composer_model = state.model + original_model: PreTrainedModel = state.model.model + state_dict_model = state.model.model + original_tokenizer = state.model.tokenizer + + state_dict_context = fsdp_state_dict_type_context( + original_model, + state_dict_type='full') if ((not state.is_model_ddp) and isinstance( + state_dict_model, FSDP)) else contextlib.nullcontext() + + with state_dict_context: + state_dict = state_dict_model.state_dict() + + # convert the state dict to the requested precision + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + state_dict[k] = v.to(dtype=self.dtype) + + new_model_instance = None # Need this for pyright because variable could be unbound + + if dist.get_global_rank() == 0: + log.debug('Saving Hugging Face checkpoint in global rank 0') + + copied_config = copy.deepcopy(original_model.config) + if copied_config.model_type == 'mpt': + copied_config.attn_config['attn_impl'] = 'torch' + copied_config.init_device = 'cpu' + + log.debug(f'Creating new model instance') + + if composer_model.using_peft: + # We don't use meta here because the state dict does not contain the full + # model, only the adapter weights. 
+ active_adapter = original_model.active_adapter + base_model = original_model.get_base_model() + new_base_model_instance = type(base_model)(copied_config) + + new_model_instance = type(original_model)( + new_base_model_instance, + original_model.peft_config[active_adapter]) + new_model_instance.to(dtype=self.dtype) else: - composer_model = state.model - original_model: PreTrainedModel = state.model.model - state_dict_model = state.model.model - original_tokenizer = state.model.tokenizer - - state_dict_context = fsdp_state_dict_type_context( - original_model, state_dict_type='full') if ( - (not state.is_model_ddp) and isinstance( - state_dict_model, FSDP)) else contextlib.nullcontext() - - with state_dict_context: - state_dict = state_dict_model.state_dict() - - # convert the state dict to the requested precision - for k, v in state_dict.items(): - if isinstance(v, torch.Tensor): - state_dict[k] = v.to(dtype=self.dtype) - - if dist.get_global_rank() == 0: - log.debug('Saving Hugging Face checkpoint in global rank 0') - - copied_config = copy.deepcopy(original_model.config) - if copied_config.model_type == 'mpt': - copied_config.attn_config['attn_impl'] = 'torch' - copied_config.init_device = 'cpu' - - log.debug(f'Creating new model instance') - - if composer_model.using_peft: - # We don't use meta here because the state dict does not contain the full - # model, only the adapter weights. - active_adapter = original_model.active_adapter - base_model = original_model.get_base_model() - new_base_model_instance = type(base_model)(copied_config) - - new_model_instance = type(original_model)( - new_base_model_instance, - original_model.peft_config[active_adapter]) - new_model_instance.to(dtype=self.dtype) - else: - # First create the model instance on meta device to avoid the - # initialization cost. - with init_empty_weights(): - new_model_instance = type(original_model)(copied_config) - - # Then load the state dict in with "assign" so that the state dict - # is loaded properly even though the model is initially on meta device. - new_model_instance.load_state_dict(state_dict, assign=True) - del state_dict - - log.debug('Saving Hugging Face checkpoint to disk') - new_model_instance.save_pretrained(temp_save_dir) - if original_tokenizer is not None: - assert isinstance(original_tokenizer, - PreTrainedTokenizerBase) - original_tokenizer.save_pretrained(temp_save_dir) - - # Only need to edit files for MPT because it has custom code - if original_model.config.model_type == 'mpt': - log.debug('Editing MPT files for HuggingFace compatibility') - edit_files_for_hf_compatibility( - temp_save_dir, - self.flatten_imports, + # First create the model instance on meta device to avoid the + # initialization cost. + with init_empty_weights(): + new_model_instance = type(original_model)(copied_config) + + # Then load the state dict in with "assign" so that the state dict + # is loaded properly even though the model is initially on meta device. 
+ new_model_instance.load_state_dict(state_dict, assign=True) + del state_dict + + log.debug('Saving Hugging Face checkpoint to disk') + new_model_instance.save_pretrained(temp_save_dir) + if original_tokenizer is not None: + assert isinstance(original_tokenizer, PreTrainedTokenizerBase) + original_tokenizer.save_pretrained(temp_save_dir) + + # Only need to edit files for MPT because it has custom code + if original_model.config.model_type == 'mpt': + log.debug('Editing MPT files for HuggingFace compatibility') + edit_files_for_hf_compatibility( + temp_save_dir, + self.flatten_imports, + ) + + if self.remote_ud is not None: + for filename in os.listdir(temp_save_dir): + remote_file_name = os.path.join(save_dir, filename) + remote_file_uri = self.remote_ud.remote_backend.get_uri( + remote_file_name) + log.info( + f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' + ) + self.remote_ud.upload_file( + state=state, + remote_file_name=remote_file_name, + file_path=Path(os.path.join(temp_save_dir, filename)), + overwrite=self.overwrite, ) - if self.remote_ud is not None: - for filename in os.listdir(temp_save_dir): - remote_file_name = os.path.join(save_dir, filename) - remote_file_uri = self.remote_ud.remote_backend.get_uri( - remote_file_name) - log.info( - f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}' - ) - self.remote_ud.upload_file( - state=state, - remote_file_name=remote_file_name, - file_path=Path(os.path.join(temp_save_dir, - filename)), - overwrite=self.overwrite, - ) + dist.barrier() - if self.mlflow_registered_model_name and self._is_last_batch( - state): - components = {'model': new_model_instance} - if original_tokenizer is not None: - components['tokenizer'] = original_tokenizer + if dist.get_global_rank() == 0: + if self.mlflow_registered_model_name and self._is_last_batch(state): + components = {'model': new_model_instance} + if original_tokenizer is not None: + components['tokenizer'] = original_tokenizer - log.debug('Logging Hugging Face model to MLFlow') - for i, mlflow_logger in enumerate(self.mlflow_loggers): - log.debug( - f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' - ) - local_save_path = str( - Path(temp_save_dir) / f'mlflow_save_{i}') - - # TODO: Remove after mlflow fixes the bug that makes this necessary - import mlflow - mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - model_saving_kwargs: Dict[str, Any] = { - 'path': local_save_path - } - if composer_model.using_peft: - model_saving_kwargs['flavor'] = 'peft' - model_saving_kwargs[ - 'save_pretrained_dir'] = temp_save_dir - model_saving_kwargs[ - 'metadata'] = self.mlflow_logging_config[ - 'metadata'] - else: - model_saving_kwargs['flavor'] = 'transformers' - model_saving_kwargs[ - 'transformers_model'] = components - model_saving_kwargs.update( - self.mlflow_logging_config) - - mlflow_logger.save_model(**model_saving_kwargs) - - # Upload the license file generated by mlflow during the model saving. 
- license_filename = _maybe_get_license_filename( - local_save_path, - self.mlflow_logging_config['metadata'].get( - 'pretrained_model_name', None)) - if license_filename is not None: - mlflow_logger._mlflow_client.log_artifact( - mlflow_logger._run_id, - os.path.join(local_save_path, license_filename), - ) - - mlflow_logger.register_model_with_run_id( - model_uri=local_save_path, - name=self.mlflow_registered_model_name, - await_creation_for=3600, + log.debug('Logging Hugging Face model to MLFlow') + for i, mlflow_logger in enumerate(self.mlflow_loggers): + log.debug( + f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}' + ) + local_save_path = str( + Path(temp_save_dir) / f'mlflow_save_{i}') + + # TODO: Remove after mlflow fixes the bug that makes this necessary + import mlflow + mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' + model_saving_kwargs: Dict[str, Any] = { + 'path': local_save_path + } + if composer_model.using_peft: + model_saving_kwargs['flavor'] = 'peft' + model_saving_kwargs[ + 'save_pretrained_dir'] = temp_save_dir + model_saving_kwargs[ + 'metadata'] = self.mlflow_logging_config['metadata'] + else: + model_saving_kwargs['flavor'] = 'transformers' + model_saving_kwargs['transformers_model'] = components + model_saving_kwargs.update(self.mlflow_logging_config) + + mlflow_logger.save_model(**model_saving_kwargs) + + # Upload the license file generated by mlflow during the model saving. + license_filename = _maybe_get_license_filename( + local_save_path, + self.mlflow_logging_config['metadata'].get( + 'pretrained_model_name', None)) + if license_filename is not None: + mlflow_logger._mlflow_client.log_artifact( + mlflow_logger._run_id, + os.path.join(local_save_path, license_filename), ) + + # Spawn a new process to register the model. + process = SpawnProcess( + target=_register_model_with_run_id_multiprocess, + kwargs={ + 'mlflow_logger': + mlflow_logger, + 'composer_logging_level': + logging.getLogger('composer').level, + 'model_uri': + local_save_path, + 'name': + self.mlflow_registered_model_name, + 'await_creation_for': + 3600, + }) + process.start() + self.child_processes.append(process) + + # Save the temporary directory to be cleaned up later. + if use_temp_dir: + self.temp_save_dir = temp_save_dir + else: + # Clean up the temporary directory if we don't need to register to mlflow. + if use_temp_dir: + shutil.rmtree(temp_save_dir) + dist.barrier() diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 9f4af3099e..9696f967ca 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -7,6 +7,7 @@ import numpy as np import torch +from composer.utils import dist from omegaconf import DictConfig from transformers import PreTrainedTokenizerBase @@ -315,6 +316,8 @@ def auto_packing_ratio(dataloader_cfg: DictConfig, """ from composer.utils import dist, get_device, reproducibility + log.debug('Searching for optimal packing ratio.') + # Stash the rng state to restore later. rng_state = reproducibility.get_rng_state() # Set the seed so that auto packing is deterministic. 
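
A quick aside on the `HuggingFaceCheckpointer` changes above: MLflow registration now runs in a spawned child process, and at `FIT_END` the callback polls until every child has exited (using a cross-rank `all_reduce` so all ranks agree) before deleting the temporary save directory. Below is a minimal single-process sketch of that spawn-and-poll pattern, using illustrative names rather than the callback's real helpers; the real callback constructs `multiprocessing.context.SpawnProcess` directly and registers via `MLFlowLogger.register_model_with_run_id`.

```python
import multiprocessing as mp
import time


def register_model(model_uri: str, name: str) -> None:
    # Stand-in for MLFlowLogger.register_model_with_run_id; the real call runs in
    # the child process so the main (training) process never blocks on registration.
    print(f'registering {model_uri} as {name}')


if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    children = []

    process = ctx.Process(
        target=register_model,
        kwargs={
            'model_uri': '/tmp/mlflow_save_0',  # hypothetical local save path
            'name': 'my-registered-model',      # hypothetical registry name
        },
    )
    process.start()
    children.append(process)

    # Poll with a timeout rather than join() so a hung registration fails loudly.
    deadline = time.time() + 3600
    while any(child.is_alive() for child in children):
        if time.time() > deadline:
            raise TimeoutError('Timed out waiting for model registration.')
        time.sleep(2)
```

Polling with a timeout instead of `join()` keeps a hung registration from stalling the run indefinitely, mirroring the one-hour timeout the callback enforces at `FIT_END`.
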
@@ -388,8 +391,19 @@ def profile_packing( dataloader_cfg.persistent_workers = False # If streaming dataset, use a temporary local folder for profiling + local_rank_zero = dist.get_global_rank() - dist.get_local_rank() if dataloader_cfg.dataset.get('remote') is not None: - dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name + tmp_path_to_broadcast = tempfile.TemporaryDirectory().name + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + tmp_path = gathered_paths[local_rank_zero] + dataloader_cfg.dataset.local = tmp_path + + if dataloader_cfg.dataset.get('streams') is not None: + for stream_config in dataloader_cfg.dataset.streams.values(): + tmp_path_to_broadcast = tempfile.TemporaryDirectory().name + gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) + tmp_path = gathered_paths[local_rank_zero] + stream_config.local = tmp_path # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] @@ -447,6 +461,12 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]: waste_percent = 100 * packer.waste return padding_percent, waste_percent - for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes): + log.debug('Profiling packing ratios') + total_packing_ratios = min(len(packing_ratios), len(raw_batch_sizes)) + for i, (packing_ratio, + raw_batch_size) in enumerate(zip(packing_ratios, raw_batch_sizes)): + log.debug( + f'Progress [{i}/{total_packing_ratios}]: Profiling packing ratio {packing_ratio}' + ) padding, waste = profile(raw_batch_size) yield (packing_ratio, padding, waste) diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py new file mode 100644 index 0000000000..9c7dabe128 --- /dev/null +++ b/llmfoundry/layers_registry.py @@ -0,0 +1,20 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Type + +import torch + +from llmfoundry.utils.registry_utils import create_registry + +# Layers +_norm_description = """The norms registry is used to register classes that implement normalization layers.""" +norms = create_registry('llmfoundry', + 'norms', + generic_type=Type[torch.nn.Module], + entry_points=True, + description=_norm_description) + +__all__ = [ + 'norms', +] diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 38ed7a7e70..5bca5cb21a 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -11,8 +11,8 @@ from composer.models.huggingface import peft_installed from composer.utils import dist from omegaconf import DictConfig -from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel, - PreTrainedTokenizerBase) +from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig, + PreTrainedModel, PreTrainedTokenizerBase) from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS, DEFAULT_CAUSAL_LM_TRAIN_METRICS) @@ -161,6 +161,18 @@ def _autoset_attn_implementation_monkeypatch( elif attr is None and isinstance(v, Mapping): setattr(config, k, {}) getattr(config, k).update(v) + elif isinstance(attr, PretrainedConfig): + if not isinstance(v, Mapping): + raise ValueError( + f'Expected a dictionary for config override {k}, but got {v}.' + ) + + for _k, _v in v.items(): + if not hasattr(attr, _k): + raise ValueError( + f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).' 
+ ) + setattr(attr, _k, _v) else: setattr(config, k, v) diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index df4216b81c..262f190b47 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -9,7 +9,7 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm +from llmfoundry.models.layers.norm import LPLayerNorm __all__ = [ 'scaled_multihead_dot_product_attention', @@ -23,7 +23,6 @@ 'ATTN_CLASS_REGISTRY', 'MPTMLP', 'MPTBlock', - 'NORM_CLASS_REGISTRY', 'LPLayerNorm', 'FC_CLASS_REGISTRY', 'SharedEmbedding', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 1deca69eb2..c24b3d4afa 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -15,7 +15,7 @@ from torch import nn from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_norm def is_flash_v2_installed(v2_version: str = '2.0.0'): @@ -419,12 +419,19 @@ def __init__( self.Wqkv._fused = (0, fuse_splits) if self.qk_ln or self.qk_gn: - norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] norm_size = self.head_dim if qk_gn else d_model - self.q_ln = norm_class(norm_size, device=device) + self.q_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) if qk_ln: norm_size = self.head_dim * kv_n_heads - self.k_ln = norm_class(norm_size, device=device) + self.k_ln = build_norm( + name=norm_type.lower(), + normalized_shape=norm_size, + device=device, + ) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn @@ -501,9 +508,10 @@ def forward( value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) elif rotary_emb_w_meta_info['impl'] == 'hf': if is_transformers_version_gte('4.38'): - (cos, sin) = rotary_emb(x=value, - position_ids=offset_info, - seq_len=None) + (cos, sin) = rotary_emb( + x=value, + position_ids=offset_info, + ) else: (cos, sin) = rotary_emb(x=value, seq_len=seq_len) if is_transformers_version_gte('4.38'): diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 855df7903f..42feb983d4 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -10,7 +10,7 @@ from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +from llmfoundry.models.layers.layer_builders import build_norm try: from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip @@ -72,7 +72,6 @@ def __init__( del kwargs # unused, just to capture any extra args from the config super().__init__() - norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] assert isinstance(attn_config['attn_type'], str) attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] @@ -88,7 +87,11 @@ def __init__( if k not in args_to_exclude_in_attn_class } - self.norm_1 = norm_class(d_model, device=device) + self.norm_1 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) self.attn = attn_class( d_model=d_model, n_heads=n_heads, @@ -100,7 +103,11 @@ def 
__init__( self.norm_2 = None if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False): - self.norm_2 = norm_class(d_model, device=device) + self.norm_2 = build_norm( + name=norm_type.lower(), + normalized_shape=d_model, + device=device, + ) self.ffn = build_ffn( d_model=d_model, expansion_ratio=expansion_ratio, diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py new file mode 100644 index 0000000000..23f5b89668 --- /dev/null +++ b/llmfoundry/models/layers/layer_builders.py @@ -0,0 +1,25 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Union + +import torch + +from llmfoundry.layers_registry import norms +from llmfoundry.utils.registry_utils import construct_from_registry + + +def build_norm( + name: str, + normalized_shape: Union[int, List[int], torch.Size], + device: Optional[str] = None, +): + kwargs = { + 'normalized_shape': normalized_shape, + 'device': device, + } + + return construct_from_registry(name=name, + registry=norms, + pre_validation_function=torch.nn.Module, + kwargs=kwargs) diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index be4f50f521..92d295c71c 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -1,10 +1,14 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Optional, Type, Union +from typing import List, Optional, Union import torch +from llmfoundry.layers_registry import norms + +norms.register(name='layernorm', func=torch.nn.LayerNorm) + def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor: if torch.is_autocast_enabled(): @@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor: return tensor +@norms.register_class('low_precision_layernorm') class LPLayerNorm(torch.nn.LayerNorm): def __init__( @@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor, return output +@norms.register_class('rmsnorm') class RMSNorm(torch.nn.Module): def __init__( @@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) +@norms.register_class('low_precision_rmsnorm') class LPRMSNorm(RMSNorm): def __init__( @@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.eps).to(dtype=x.dtype) +@norms.register_class('triton_rmsnorm') class TritonRMSNorm(torch.nn.Module): def __init__( @@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor): prenorm=False, residual_in_fp32=False, ) - - -NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = { - 'layernorm': torch.nn.LayerNorm, - 'low_precision_layernorm': LPLayerNorm, - 'rmsnorm': RMSNorm, - 'low_precision_rmsnorm': LPRMSNorm, - 'triton_rmsnorm': TritonRMSNorm, -} diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 20c3850a82..2f58ea312e 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -19,6 +19,9 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note) +from llmfoundry.layers_registry import norms # type: ignore (see note) +from 
llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note) ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 183e1b24f6..d54b797269 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -20,7 +20,6 @@ from composer.utils import dist from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -42,11 +41,13 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.ffn import build_ffn as build_ffn +from llmfoundry.models.layers.layer_builders import build_norm from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig): else: config.init_device = 'meta' - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) + if config.norm_type.lower() not in norms.get_all(): + norm_options = ' | '.join(norms.get_all()) raise NotImplementedError( f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).' ) - norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414) # both report this helping with stabilizing training @@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig): block.max_block_idx = config.n_layers - 1 pass_on_block_idx(block) - self.norm_f = norm_class(config.d_model, device=config.init_device) + self.norm_f = build_norm( + name=config.norm_type.lower(), + normalized_shape=config.d_model, + device=config.init_device, + ) self.rope = config.attn_config['rope'] self.rope_impl = None diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py index 9acd7dd11c..bde7c92bd7 100644 --- a/llmfoundry/models/utils/act_ckpt.py +++ b/llmfoundry/models/utils/act_ckpt.py @@ -5,10 +5,10 @@ import torch +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY def pass_on_block_idx(parent: torch.nn.Module): @@ -29,12 +29,12 @@ def get_act_ckpt_module(mod_name: str) -> Any: mod_type = ATTN_CLASS_REGISTRY[mod_name] elif mod_name in FFN_CLASS_REGISTRY: mod_type = FFN_CLASS_REGISTRY[mod_name] - elif mod_name in NORM_CLASS_REGISTRY: - mod_type = NORM_CLASS_REGISTRY[mod_name] + elif mod_name in norms: + mod_type = norms.get(mod_name) else: msg = ', '.join( list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + - list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + list(norms.get_all()) + ['MPTBlock']) raise ValueError( f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available 
options {msg}.' ) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 2e72ccfa47..35dc88a408 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -10,8 +10,8 @@ import torch from torch import nn +from llmfoundry.layers_registry import norms from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY try: import transformer_engine.pytorch as te @@ -129,7 +129,8 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): + elif isinstance(module, + tuple(set([norms.get(name) for name in norms.get_all()]))): # Norm if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index e289a923b6..424075da3b 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -12,6 +12,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.interfaces import CallbackWithConfig +from llmfoundry.layers_registry import norms from llmfoundry.utils.registry_utils import create_registry _loggers_description = ( @@ -119,4 +120,5 @@ 'models', 'metrics', 'dataloaders', + 'norms', ] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index fe803d62db..a8c660df70 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -104,6 +104,7 @@ def build_eval_loaders( # Load the eval data to fail fast. metrics will get added # later in add_metrics_to_eval_loaders, after the model is loaded metric_names=[], + device_eval_microbatch_size=device_eval_batch_size, ) evaluators.append(eval_loader) return evaluators diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py index 7089996a13..0901ea198a 100644 --- a/llmfoundry/utils/registry_utils.py +++ b/llmfoundry/utils/registry_utils.py @@ -14,6 +14,7 @@ __all__ = ['TypedRegistry', 'create_registry', 'construct_from_registry'] T = TypeVar('T') +TypeBoundT = TypeVar('TypeBoundT', bound=Type) class TypedRegistry(catalogue.Registry, Generic[T]): @@ -36,6 +37,12 @@ def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: def register(self, name: str, *, func: Optional[T] = None) -> T: return super().register(name, func=func) + def register_class(self, + name: str, + *, + func: Optional[TypeBoundT] = None) -> TypeBoundT: + return super().register(name, func=func) + def get(self, name: str) -> T: return super().get(name) diff --git a/mcli/mcli-1b-eval.yaml b/mcli/mcli-1b-eval.yaml index f577369cfa..51511838d6 100644 --- a/mcli/mcli-1b-eval.yaml +++ b/mcli/mcli-1b-eval.yaml @@ -1,15 +1,15 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e .[gpu] + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo command: | cd llm-foundry/scripts/ composer eval/eval.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: mpt-1b-eval compute: diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml index af0f7c356d..33a891c058 100644 --- a/mcli/mcli-1b-max-seq-len-8k.yaml +++ b/mcli/mcli-1b-max-seq-len-8k.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + 
git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e .[gpu-flash2] + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo # We are fetching, converting, and training on the 'val' split @@ -17,7 +17,7 @@ command: | --out_root ./my-copy-c4 --splits train_small val_small \ --concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>' composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: mpt-1b-ctx-8k-gpus-8 compute: diff --git a/mcli/mcli-1b.yaml b/mcli/mcli-1b.yaml index 00693816fb..0eeec4e652 100644 --- a/mcli/mcli-1b.yaml +++ b/mcli/mcli-1b.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e .[gpu-flash2] + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo # We are fetching, converting, and training on the 'val' split @@ -21,7 +21,7 @@ command: | eval_loader.dataset.split=val_small \ max_duration=100ba \ eval_interval=0 -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: mpt-1b-gpus-8 compute: diff --git a/mcli/mcli-benchmark-mpt.yaml b/mcli/mcli-benchmark-mpt.yaml index efe31568a4..364df79ad0 100644 --- a/mcli/mcli-benchmark-mpt.yaml +++ b/mcli/mcli-benchmark-mpt.yaml @@ -6,14 +6,14 @@ compute: # cluster: TODO # Name of the cluster to use for this run # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: ".[gpu-flash2]" + pip_install: .[gpu] command: | cd llm-foundry/scripts/inference/benchmarking diff --git a/mcli/mcli-convert-composer-to-hf.yaml b/mcli/mcli-convert-composer-to-hf.yaml index 05527299d8..0e20dc8748 100644 --- a/mcli/mcli-convert-composer-to-hf.yaml +++ b/mcli/mcli-convert-composer-to-hf.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e . + pip_install: . ssh_clone: false # Should be true if using a private repo command: | @@ -13,7 +13,7 @@ command: | --hf_output_path s3://bucket/folder/hf/ \ --output_precision bf16 \ -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: convert-composer-hf compute: diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml index bed2627797..0dbeba3a41 100644 --- a/mcli/mcli-hf-eval.yaml +++ b/mcli/mcli-hf-eval.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e ".[gpu-flash2]" + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo command: | @@ -16,7 +16,7 @@ gpu_num: 8 # gpu_type: # cluster: # replace with your cluster here! 
-image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-hf-generate.yaml b/mcli/mcli-hf-generate.yaml index 3ae2b6558a..70a421cc57 100644 --- a/mcli/mcli-hf-generate.yaml +++ b/mcli/mcli-hf-generate.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e .[gpu-flash2] + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo command: | @@ -35,7 +35,7 @@ command: | "Here's a quick recipe for baking chocolate chip cookies: Start by" \ "The best 5 cities to visit in Europe are" -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: hf-generate compute: diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml index 4f4d994eb1..6076e02ef0 100644 --- a/mcli/mcli-llama2-finetune.yaml +++ b/mcli/mcli-llama2-finetune.yaml @@ -1,15 +1,15 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e .[gpu-flash2] + pip_install: .[gpu] ssh_clone: false # Should be true if using a private repo command: | cd llm-foundry/scripts composer train/train.py /mnt/config/parameters.yaml -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest name: llama2-finetune compute: diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml index f774c92560..4046cffbca 100644 --- a/mcli/mcli-openai-eval.yaml +++ b/mcli/mcli-openai-eval.yaml @@ -1,9 +1,9 @@ integrations: - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: -e ".[gpu-flash2,openai]" + pip_install: .[gpu,openai] ssh_clone: false # Should be true if using a private repo command: | @@ -16,7 +16,7 @@ gpu_num: # gpu_type: # cluster: # replace with your cluster here! -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest # The below is injected as a YAML file: /mnt/config/parameters.yaml parameters: diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml index 5de105d9ba..bfc1bfa858 100644 --- a/mcli/mcli-pretokenize-oci-upload.yaml +++ b/mcli/mcli-pretokenize-oci-upload.yaml @@ -1,5 +1,5 @@ name: c4-2k-pre-tokenized -image: mosaicml/llm-foundry:2.1.0_cu121_flash2-latest +image: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest compute: gpus: 8 # Number of GPUs to use @@ -14,9 +14,9 @@ integrations: - oci-cli==3.23.2 - integration_type: git_repo git_repo: mosaicml/llm-foundry - git_branch: v0.6.0 + git_branch: v0.7.0 # git_commit: # OR use your commit hash - pip_install: "." + pip_install: . 
ssh_clone: false # Should be true if using a private repo command: | diff --git a/scripts/train/train.py b/scripts/train/train.py index 3244a7ecb9..b98dd2680c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -197,9 +197,9 @@ def main(cfg: DictConfig) -> Trainer: if max_split_size_mb is not None: cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') - # Expandeable segments - if cfg.pop('expandeable_segments', False): - cuda_alloc_conf.append('expandeable_segments:True') + # Expandable segments + if cfg.pop('expandable_segments', False): + cuda_alloc_conf.append('expandable_segments:True') if len(cuda_alloc_conf) > 0: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ','.join(cuda_alloc_conf) diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml index 8933a2d3c8..a0e2504787 100644 --- a/scripts/train/yamls/finetune/dbrx-full-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml @@ -88,7 +88,7 @@ device_eval_batch_size: 1 precision: amp_bf16 autoresume: true dist_timeout: 3600 -expandeable_segments: true +expandable_segments: true # FSDP fsdp_config: diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml index 58ca8fae61..7fb921ae16 100644 --- a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml @@ -96,7 +96,7 @@ device_eval_batch_size: 1 precision: amp_bf16 autoresume: true dist_timeout: 3600 -expandeable_segments: true +expandable_segments: true # FSDP fsdp_config: diff --git a/setup.py b/setup.py index 22b7cb17ca..79511eeca3 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ 'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22', 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.38.2,<4.39', + 'transformers>=4.39.3,<4.40', 'mosaicml-streaming>=0.7.4,<0.8', 'torch>=2.2.1,<2.3', 'datasets>=2.16,<2.17', @@ -67,7 +67,7 @@ 'onnx==1.14.0', 'onnxruntime==1.15.1', 'boto3>=1.21.45,<2', - 'huggingface-hub>=0.17.0,<1.0', + 'huggingface-hub>=0.19.0,<1.0', 'beautifulsoup4>=4.12.2,<5', # required for model download utils 'tenacity>=8.2.3,<9', 'catalogue>=2,<3', diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 3949c091aa..7b4ef1e058 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -7,13 +7,13 @@ import pathlib import shutil from argparse import Namespace -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Dict, Optional, cast from unittest.mock import ANY, MagicMock, patch import pytest import torch import transformers -from composer import Trainer +from composer import ComposerModel, Trainer from composer.loggers import MLFlowLogger from composer.utils import dist, get_device from omegaconf import DictConfig @@ -29,6 +29,14 @@ from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset +_OPTIMIZER_CFG = lambda: { + 'name': 'decoupled_adamw', + 'lr': 6e-4, + 'betas': [0.9, 0.95], + 'eps': 1e-8, + 'weight_decay': 0.0, +} + def _save_model_mock(*args: Any, path: str, **kwargs: Any): os.makedirs(path, exist_ok=True) @@ -256,12 +264,34 @@ def test_callback_inits(): assert hf_checkpointer.mlflow_logging_config['task'] == 'llm/v1/completions' +class MockSpawnProcess: + """Class for mocking 
`multiprocessing.context.SpawnProcess`. + + Runs `target(**kwargs)` on the main process. + + Mock classes are not picklable and therefore cannot be used with + multiprocessing, so we need to patch SpawnProcess for tests. + """ + + def __init__(self, target: Callable, kwargs: Dict[str, Any]): + self.target = target + self.kwargs = kwargs + + def start(self): + self.target(**self.kwargs) + + def is_alive(self) -> bool: + return False + + @pytest.mark.gpu @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)]) @patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) def test_huggingface_conversion_callback_interval( tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str, save_interval: str, max_duration: str, expected_hf_checkpoints: int, @@ -287,13 +317,7 @@ def test_huggingface_conversion_callback_interval( original_model = build_tiny_mpt() - optimizer_config = { - 'name': 'decoupled_adamw', - 'lr': 6e-4, - 'betas': [0.9, 0.95], - 'eps': 1e-8, - 'weight_decay': 0.0, - } + optimizer_config = _OPTIMIZER_CFG() optimizer_name = optimizer_config.pop('name') optimizer = build_optimizer(original_model, optimizer_name, optimizer_config) @@ -378,66 +402,8 @@ def test_huggingface_conversion_callback_interval( delete_transformers_cache() -@pytest.mark.world_size(2) -@pytest.mark.gpu -@pytest.mark.parametrize( - 'model,tie_word_embeddings,peft_config', - [ - ('mpt', True, None), - ('mpt', False, None), - ('neo', None, None), - ('llama2', None, None), - ('llama2', None, { - 'peft_type': 'LORA', - 'task_type': 'CAUSAL_LM', - 'lora_alpha': 32, - 'lora_dropout': 0.05, - 'r': 16, - 'target_modules': [ - 'q_proj', - 'k_proj', - 'v_proj', - ], - }), - ], -) -@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) -@pytest.mark.parametrize( - 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', - [('1ba', '1ba', '1ba', 1, 1)]) -@patch('os.cpu_count', MagicMock(return_value=1)) -def test_huggingface_conversion_callback( - model: str, - tmp_path: pathlib.Path, - tie_word_embeddings: bool, - fsdp_state_dict_type: Optional[str], - hf_save_interval: str, - save_interval: str, - max_duration: str, - expected_hf_checkpoints: int, - expected_normal_checkpoints: int, - peft_config: Optional[dict], -): - delete_transformers_cache() - - dist.initialize_dist(get_device('gpu')) - - max_seq_len = 16 - device_batch_size = 1 - dataset_size = 2 - precision_str = 'bfloat16' - precision = torch.bfloat16 - batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2)) - - checkpointer_callback = HuggingFaceCheckpointer( - save_folder=os.path.join(tmp_path, 'checkpoints'), - save_interval=hf_save_interval, - precision=precision_str, - mlflow_registered_model_name='dummy-registered-name') - - # get small version of each model - model_cfg = None - tokenizer_name = None +def _get_model_and_tokenizer(model: str, max_seq_len: int, + tie_word_embeddings: bool): if model == 'mpt': model_cfg = { 'name': 'mpt_causal_lm', @@ -489,12 +455,43 @@ def test_huggingface_conversion_callback( tokenizer_name = 'meta-llama/Llama-2-7b-hf' else: raise ValueError(f'Unknown model {model}') - assert model_cfg is not None - assert tokenizer_name is not None - model_cfg = om.create(model_cfg) - if peft_config is 
not None: - model_cfg['peft_config'] = peft_config + return model_cfg, tokenizer_name + + +def _assert_mlflow_logger_calls(mlflow_logger_mock: MagicMock, + peft_config: Optional[dict] = None): + if dist.get_global_rank() == 0: + assert mlflow_logger_mock.save_model.call_count == 1 + if peft_config is not None: + expectation = { + 'flavor': 'peft', + 'path': ANY, + 'save_pretrained_dir': ANY, + 'metadata': {}, + } + else: + import numpy as np + + default_input_example = { + 'prompt': np.array(['What is Machine Learning?']) + } + expectation = { + 'flavor': 'transformers', + 'transformers_model': ANY, + 'path': ANY, + 'task': 'llm/v1/completions', + 'input_example': default_input_example, + 'metadata': {} + } + mlflow_logger_mock.save_model.assert_called_with(**expectation) + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 + else: + assert mlflow_logger_mock.log_model.call_count == 0 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 + + +def _get_fsdp_config(fsdp_state_dict_type: Optional[str]): fsdp_config = { 'sharding_strategy': 'FULL_SHARD', 'mixed_precision': 'PURE', @@ -504,12 +501,10 @@ def test_huggingface_conversion_callback( 'limit_all_gathers': True, 'state_dict_type': fsdp_state_dict_type, } + return fsdp_config - tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small') - tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl') - if dist.get_global_rank() == 0: - make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size) +def _get_dataloader_cfg(tiny_dataset_folder_path: str, max_seq_len: int): dataloader_cfg = { 'name': 'finetuning', 'dataset': { @@ -528,6 +523,198 @@ def test_huggingface_conversion_callback( 'persistent_workers': False, 'timeout': 0 } + return dataloader_cfg + + +def _assert_checkpoint_equivalence(tmp_path: pathlib.Path, + expected_normal_checkpoints: int, + expected_hf_checkpoints: int, + trainer: Trainer, + batches_per_epoch: int, + precision: torch.dtype, + model: str, + tokenizer: PreTrainedTokenizerBase, + original_model: ComposerModel, + fsdp_state_dict_type: Optional[str] = None, + peft_config: Optional[dict] = None): + """Asserts the equivalence of checkpoints. + + Asserts equivalence of checkpoints between the original mpt model and the converted hf model. + + Args: + tmp_path (str): The path to the temporary directory where the checkpoints are saved. + expected_normal_checkpoints (int): The expected number of normal checkpoints. + expected_hf_checkpoints (int): The expected number of HuggingFace checkpoints. + trainer (Trainer): The trainer object used for training the model. + batches_per_epoch (int): The number of batches per epoch. + precision (torch.dtype): The precision of the model. + model (str): The type of model ('mpt', 'neo', or 'llama2'). + tokenizer (PreTrainedTokenizerBase): The model tokenizer. + original_model (ComposerModel): The original model object. + fsdp_state_dict_type (Optional[str], optional): The type of FSDP state dict. Defaults to None. + peft_config (Optional[dict], optional): The PEFT configuration. Defaults to None. 
+ """ + loaded_model = None + loaded_tokenizer = None + # Only rank zero is saving the huggingface checkpoints, so only check + # for equivalence on rank zero + if dist.get_global_rank() == 0: + normal_checkpoints = [ + name for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) + if name != 'huggingface' + ] + huggingface_checkpoints = [ + name for name in os.listdir( + os.path.join(tmp_path, 'checkpoints', 'huggingface')) + ] + + checkpoint_files = os.listdir( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + huggingface_checkpoints[-1])) + if peft_config is not None: + assert 'adapter_config.json' in checkpoint_files + assert 'adapter_model.safetensors' in checkpoint_files + + assert len(normal_checkpoints) == expected_normal_checkpoints + assert len(huggingface_checkpoints) == expected_hf_checkpoints + + # Patch flash_attn package to be empty to simulate loading the model in + # an environment without flash attention installed + with patch.dict('sys.modules', {'flash_attn': None}): + if peft_config is not None: + composer_model = trainer.state.model.module if trainer.state.is_model_ddp else trainer.state.model + composer_model.model.base_model.save_pretrained(tmp_path / + 'base-model') + + checkpoint_path = os.path.join(tmp_path, 'checkpoints', + 'huggingface', + f'ba{batches_per_epoch}') + + if peft_config is not None: + with open(os.path.join(checkpoint_path, + 'adapter_config.json')) as _f: + adapter_config = json.load(_f) + + adapter_config['base_model_name_or_path'] = str(tmp_path / + 'base-model') + + with open(os.path.join(checkpoint_path, 'adapter_config.json'), + 'w') as _f: + json.dump(adapter_config, _f) + + # Load the last huggingface checkpoint + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=True, + ) + + # Check that the loaded model has the correct precision, and then set it back + # to the original for the equivalence check + if peft_config is None: + assert loaded_model.config.torch_dtype == precision + loaded_model.config.torch_dtype = original_model.model.config.torch_dtype + + if model == 'mpt': + # Check that we have correctly set these attributes, and then set them back + # to the original for the equivalence check + assert loaded_model.config.attn_config['attn_impl'] == 'torch' + assert loaded_model.config.init_device == 'cpu' + loaded_model.config.attn_config[ + 'attn_impl'] = original_model.model.config.attn_config[ + 'attn_impl'] + loaded_model.config.init_device = original_model.model.config.init_device + + loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba{batches_per_epoch}'), + trust_remote_code=True, + ) + + check_hf_model_equivalence( + trainer.state.model.model.to(precision) if fsdp_state_dict_type + is not None else trainer.state.model.module.model.to(precision), + loaded_model, + just_lora=peft_config is not None) + check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer) + + +@pytest.mark.world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize( + 'model,tie_word_embeddings,peft_config', + [ + ('mpt', True, None), + ('mpt', False, None), + ('neo', None, None), + ('llama2', None, None), + ('llama2', None, { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM', + 'lora_alpha': 32, + 'lora_dropout': 0.05, + 'r': 16, + 'target_modules': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + }), + ], +) +@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) +@pytest.mark.parametrize( + 
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', + [('1ba', '1ba', '1ba', 1, 1)]) +@patch('os.cpu_count', MagicMock(return_value=1)) +@patch('llmfoundry.callbacks.hf_checkpointer.SpawnProcess', + new=MockSpawnProcess) +def test_huggingface_conversion_callback( + model: str, + tmp_path: pathlib.Path, + tie_word_embeddings: bool, + fsdp_state_dict_type: Optional[str], + hf_save_interval: str, + save_interval: str, + max_duration: str, + expected_hf_checkpoints: int, + expected_normal_checkpoints: int, + peft_config: Optional[dict], +): + delete_transformers_cache() + + dist.initialize_dist(get_device('gpu')) + + max_seq_len = 16 + device_batch_size = 1 + dataset_size = 2 + precision_str = 'bfloat16' + precision = torch.bfloat16 + batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2)) + + checkpointer_callback = HuggingFaceCheckpointer( + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval=hf_save_interval, + precision=precision_str, + mlflow_registered_model_name='dummy-registered-name') + + # get small version of each model + model_cfg, tokenizer_name = _get_model_and_tokenizer( + model, max_seq_len, tie_word_embeddings) + assert model_cfg is not None + assert tokenizer_name is not None + model_cfg = om.create(model_cfg) + if peft_config is not None: + model_cfg['peft_config'] = peft_config + + fsdp_config = _get_fsdp_config(fsdp_state_dict_type) + optimizer_config = _OPTIMIZER_CFG() + + tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small') + tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl') + if dist.get_global_rank() == 0: + make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size) + + dataloader_cfg = _get_dataloader_cfg(tiny_dataset_folder_path, max_seq_len) dataloader_cfg = om.create(dataloader_cfg) @@ -542,19 +729,8 @@ def test_huggingface_conversion_callback( device_batch_size, ) - original_model = build_composer_model( - name=model_cfg['name'], - cfg=model_cfg, - tokenizer=tokenizer, - ) - - optimizer_config = { - 'name': 'decoupled_adamw', - 'lr': 6e-4, - 'betas': [0.9, 0.95], - 'eps': 1e-8, - 'weight_decay': 0.0, - } + original_model = build_composer_model(model_cfg['name'], model_cfg, + tokenizer) optimizer_name = optimizer_config.pop('name') optimizer = build_optimizer(original_model, optimizer_name, optimizer_config) @@ -581,126 +757,25 @@ def test_huggingface_conversion_callback( ) trainer.fit() - if dist.get_global_rank() == 0: - assert mlflow_logger_mock.save_model.call_count == 1 - if peft_config is not None: - expectation = { - 'flavor': 'peft', - 'path': ANY, - 'save_pretrained_dir': ANY, - 'metadata': {}, - } - else: - import numpy as np - - default_input_example = { - 'prompt': np.array(['What is Machine Learning?']) - } - - expectation = { - 'flavor': 'transformers', - 'transformers_model': ANY, - 'path': ANY, - 'task': 'llm/v1/completions', - 'input_example': default_input_example, - 'metadata': {} - } - mlflow_logger_mock.save_model.assert_called_with(**expectation) - assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 - else: - assert mlflow_logger_mock.log_model.call_count == 0 - assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 + _assert_mlflow_logger_calls(mlflow_logger_mock, peft_config) # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP with FSDP.summon_full_params(trainer.state.model, writeback=False, recurse=True): - loaded_model = None 
@@ -581,126 +757,25 @@ def test_huggingface_conversion_callback(
     )
     trainer.fit()
 
-    if dist.get_global_rank() == 0:
-        assert mlflow_logger_mock.save_model.call_count == 1
-        if peft_config is not None:
-            expectation = {
-                'flavor': 'peft',
-                'path': ANY,
-                'save_pretrained_dir': ANY,
-                'metadata': {},
-            }
-        else:
-            import numpy as np
-
-            default_input_example = {
-                'prompt': np.array(['What is Machine Learning?'])
-            }
-
-            expectation = {
-                'flavor': 'transformers',
-                'transformers_model': ANY,
-                'path': ANY,
-                'task': 'llm/v1/completions',
-                'input_example': default_input_example,
-                'metadata': {}
-            }
-        mlflow_logger_mock.save_model.assert_called_with(**expectation)
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
-    else:
-        assert mlflow_logger_mock.log_model.call_count == 0
-        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
+    _assert_mlflow_logger_calls(mlflow_logger_mock, peft_config)
 
     # summon full params to check equivalence
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     with FSDP.summon_full_params(trainer.state.model,
                                  writeback=False,
                                  recurse=True):
-        loaded_model = None
-        loaded_tokenizer = None
-        # Only rank zero is saving the huggingface checkpoints, so only check
-        # for equivalence on rank zero
-        if dist.get_global_rank() == 0:
-            normal_checkpoints = [
-                name
-                for name in os.listdir(os.path.join(tmp_path, 'checkpoints'))
-                if name != 'huggingface'
-            ]
-            huggingface_checkpoints = [
-                name for name in os.listdir(
-                    os.path.join(tmp_path, 'checkpoints', 'huggingface'))
-            ]
-
-            checkpoint_files = os.listdir(
-                os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                             huggingface_checkpoints[-1]))
-            if peft_config is not None:
-                assert 'adapter_config.json' in checkpoint_files
-                assert 'adapter_model.safetensors' in checkpoint_files
-
-            assert len(normal_checkpoints) == expected_normal_checkpoints
-            assert len(huggingface_checkpoints) == expected_hf_checkpoints
-
-            # Patch flash_attn package to be empty to simulate loading the model in
-            # an environment without flash attention installed
-            with patch.dict('sys.modules', {'flash_attn': None}):
-                if peft_config is not None:
-                    composer_model = trainer.state.model.module if trainer.state.is_model_ddp else trainer.state.model
-                    composer_model.model.base_model.save_pretrained(
-                        tmp_path / 'base-model')
-
-                checkpoint_path = os.path.join(tmp_path, 'checkpoints',
-                                               'huggingface',
-                                               f'ba{batches_per_epoch}')
-
-                if peft_config is not None:
-                    with open(
-                            os.path.join(checkpoint_path,
-                                         'adapter_config.json')) as _f:
-                        adapter_config = json.load(_f)
-
-                    adapter_config['base_model_name_or_path'] = str(
-                        tmp_path / 'base-model')
-
-                    with open(
-                            os.path.join(checkpoint_path,
-                                         'adapter_config.json'), 'w') as _f:
-                        json.dump(adapter_config, _f)
-
-                # Load the last huggingface checkpoint
-                loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
-                    checkpoint_path,
-                    trust_remote_code=True,
-                )
-
-                # Check that the loaded model has the correct precision, and then set it back
-                # to the original for the equivalence check
-                if peft_config is None:
-                    assert loaded_model.config.torch_dtype == precision
-                    loaded_model.config.torch_dtype = original_model.model.config.torch_dtype
-
-                if model == 'mpt':
-                    # Check that we have correctly set these attributes, and then set them back
-                    # to the original for the equivalence check
-                    assert loaded_model.config.attn_config['attn_impl'] == 'torch'
-                    assert loaded_model.config.init_device == 'cpu'
-                    loaded_model.config.attn_config[
-                        'attn_impl'] = original_model.model.config.attn_config[
-                            'attn_impl']
-                    loaded_model.config.init_device = original_model.model.config.init_device
-
-                loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
-                    os.path.join(tmp_path, 'checkpoints', 'huggingface',
-                                 f'ba{batches_per_epoch}'),
-                    trust_remote_code=True,
-                )
-
-            check_hf_model_equivalence(
-                trainer.state.model.model.to(precision) if fsdp_state_dict_type
-                is not None else trainer.state.model.module.model.to(precision),
-                loaded_model,
-                just_lora=peft_config is not None)
-            check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
+        _assert_checkpoint_equivalence(
+            tmp_path=tmp_path,
+            expected_normal_checkpoints=expected_normal_checkpoints,
+            expected_hf_checkpoints=expected_hf_checkpoints,
+            trainer=trainer,
+            batches_per_epoch=batches_per_epoch,
+            original_model=original_model,
+            precision=precision,
+            model=model,
+            tokenizer=tokenizer,
+            fsdp_state_dict_type=fsdp_state_dict_type,
+            peft_config=peft_config)
 
     dist.barrier()
     delete_transformers_cache()
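The rank-aware MLflow assertions deleted above now live in `_assert_mlflow_logger_calls`, which is defined earlier in the test module and falls outside this hunk. A sketch of that helper, reconstructed from the removed lines (the actual implementation may differ in details):

```python
from typing import Any, Dict, Optional
from unittest.mock import ANY, MagicMock

from composer.utils import dist


def _assert_mlflow_logger_calls(mlflow_logger_mock: MagicMock,
                                peft_config: Optional[Dict[str, Any]] = None):
    # Rank 0 is the only rank that saves and registers the model.
    if dist.get_global_rank() == 0:
        assert mlflow_logger_mock.save_model.call_count == 1
        if peft_config is not None:
            expectation = {
                'flavor': 'peft',
                'path': ANY,
                'save_pretrained_dir': ANY,
                'metadata': {},
            }
        else:
            import numpy as np
            default_input_example = {
                'prompt': np.array(['What is Machine Learning?'])
            }
            expectation = {
                'flavor': 'transformers',
                'transformers_model': ANY,
                'path': ANY,
                'task': 'llm/v1/completions',
                'input_example': default_input_example,
                'metadata': {},
            }
        mlflow_logger_mock.save_model.assert_called_with(**expectation)
        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
    else:
        # Non-zero ranks should never touch MLflow.
        assert mlflow_logger_mock.log_model.call_count == 0
        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
```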
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index d5de596199..e79756aba3 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -12,7 +12,7 @@
 import torch
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PretrainedConfig
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
@@ -205,3 +205,34 @@ def test_rope_scaling_override():
     # This would error if the config isn't parsed into a proper dictionary
     model.get_metadata()
     assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
+
+
+@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
+                    reason='CI does not have access to Dbrx')
+def test_nested_override():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'databricks/dbrx-instruct',
+        'config_overrides': {
+            'ffn_config': {
+                'ffn_hidden_size': 500,
+            }
+        },
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'meta',
+    }
+    model_cfg = om.create(model_cfg)
+
+    model = build_composer_model(
+        name=model_cfg.name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    # The value we changed
+    assert model.config.ffn_config.ffn_hidden_size == 500
+    # Ensure we still have a config, and haven't replaced it with a dictionary
+    assert isinstance(model.config.ffn_config, PretrainedConfig)
+    # Ensure the other values still exist and are not set back to their defaults
+    assert model.config.ffn_config.moe_num_experts == 16
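The new `test_nested_override` pins down how `config_overrides` should interact with nested sub-configs: the override is merged into the existing `ffn_config` rather than replacing it with a plain dict, so untouched defaults such as `moe_num_experts` survive. As a rough illustration of that merge behaviour (not llm-foundry's actual implementation), a recursive application of overrides looks roughly like this:

```python
from typing import Any, Dict

from transformers import PretrainedConfig


def apply_config_overrides(config: PretrainedConfig,
                           overrides: Dict[str, Any]) -> None:
    """Illustrative recursion only; not llm-foundry's code.

    Nested dict overrides are applied attribute-by-attribute so that a
    sub-config such as DBRX's `ffn_config` keeps its PretrainedConfig type
    and its untouched defaults (e.g. `moe_num_experts`).
    """
    for key, value in overrides.items():
        current = getattr(config, key, None)
        if isinstance(value, dict) and isinstance(current, PretrainedConfig):
            # Recurse into the sub-config instead of overwriting it.
            apply_config_overrides(current, value)
        else:
            setattr(config, key, value)
```

A naive `setattr(config, 'ffn_config', {'ffn_hidden_size': 500})` would make the second and third assertions above fail, which is exactly what the test guards against.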
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index c5f6062b0e..7bd8292151 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -26,8 +26,9 @@
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
 
 from llmfoundry import ComposerHFCausalLM
+from llmfoundry.layers_registry import norms
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
-from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
+from llmfoundry.models.layers import build_alibi_bias
 from llmfoundry.models.layers.attention import (check_alibi_support,
                                                 is_flash_v2_installed)
 from llmfoundry.models.layers.blocks import MPTBlock
@@ -682,7 +683,7 @@ def test_lora_id():
     assert isinstance(model.model, peft.PeftModelForCausalLM)
 
 
-@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
+@pytest.mark.parametrize('norm_type', norms.get_all())
 @pytest.mark.parametrize('no_bias', [False, True])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 @pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
diff --git a/tests/models/test_rmsnorm_triton_vs_eager.py b/tests/models/test_rmsnorm_triton_vs_eager.py
index 1902f46d78..7169c5d926 100644
--- a/tests/models/test_rmsnorm_triton_vs_eager.py
+++ b/tests/models/test_rmsnorm_triton_vs_eager.py
@@ -8,6 +8,7 @@
 from composer.core.precision import get_precision_context
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.layer_builders import build_norm
 
 
 @pytest.mark.gpu
@@ -19,17 +20,18 @@ def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]],
         pytest.skip(
            'triton implementation of rmsnorm requires flash attention 2.')
 
-    from llmfoundry.models.layers import norm
-
     batch_size = 2
-    cfg = {
-        'normalized_shape': normalized_shape,
-        'device': device,
-    }
-
-    eager_rmsnorm = norm.NORM_CLASS_REGISTRY['rmsnorm'](**cfg)
-    triton_rmsnorm = norm.NORM_CLASS_REGISTRY['triton_rmsnorm'](**cfg)
+    eager_rmsnorm = build_norm(
+        name='rmsnorm',
+        normalized_shape=normalized_shape,
+        device=device,
+    )
+    triton_rmsnorm = build_norm(
+        name='triton_rmsnorm',
+        normalized_shape=normalized_shape,
+        device=device,
+    )
 
     triton_rmsnorm.load_state_dict(eager_rmsnorm.state_dict())
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 30f6e0e38f..c93c7c9749 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -30,6 +30,7 @@ def test_expected_registries_exist():
         'dataloaders',
         'metrics',
         'models',
+        'norms',
     }
 
     assert existing_registries == expected_registry_names
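Taken together, the last three files track the move from the old `NORM_CLASS_REGISTRY` dict to the `norms` registry plus the `build_norm` builder. A small usage sketch consistent with the imports and calls shown above (the exact set of registered names beyond `rmsnorm`/`triton_rmsnorm` is not visible in this diff):

```python
import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.layer_builders import build_norm

# Enumerate registered norm names (replaces NORM_CLASS_REGISTRY.keys()).
print(norms.get_all())

# Build a norm module by name instead of indexing a class registry.
rmsnorm = build_norm(
    name='rmsnorm',
    normalized_shape=64,
    device='cpu',
)
out = rmsnorm(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])
```

The `test_registry.py` change is the bookkeeping side of the same refactor: `norms` must now appear in the set of expected registries, or `test_expected_registries_exist` fails.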